simondh commited on
Commit
6f39808
·
1 Parent(s): b9ab46a
Files changed (4) hide show
  1. app.py +314 -159
  2. classifiers.py +72 -66
  3. prompts.py +1 -1
  4. utils.py +61 -48
app.py CHANGED
@@ -20,12 +20,13 @@ from prompts import (
20
  CATEGORY_SUGGESTION_PROMPT,
21
  ADDITIONAL_CATEGORY_PROMPT,
22
  VALIDATION_ANALYSIS_PROMPT,
23
- CATEGORY_IMPROVEMENT_PROMPT
24
  )
25
 
26
  # Configure logging
27
- logging.basicConfig(level=logging.INFO,
28
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
29
 
30
  # Initialize API key from environment variable
31
  OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
@@ -39,22 +40,23 @@ if OPENAI_API_KEY:
39
  except Exception as e:
40
  logging.error(f"Failed to initialize OpenAI client: {str(e)}")
41
 
 
42
  def update_api_key(api_key):
43
  """Update the OpenAI API key"""
44
  global OPENAI_API_KEY, client
45
-
46
  if not api_key:
47
  return "API Key cannot be empty"
48
-
49
  OPENAI_API_KEY = api_key
50
-
51
  try:
52
  client = OpenAI(api_key=api_key)
53
  # Test the connection with a simple request
54
  response = client.chat.completions.create(
55
  model="gpt-3.5-turbo",
56
  messages=[{"role": "user", "content": "test"}],
57
- max_tokens=5
58
  )
59
  return f"API Key updated and verified successfully"
60
  except Exception as e:
@@ -62,41 +64,45 @@ def update_api_key(api_key):
62
  logging.error(f"API key update failed: {error_msg}")
63
  return f"Failed to update API Key: {error_msg}"
64
 
 
65
  def process_file(file, text_columns, categories, classifier_type, show_explanations):
66
  """Process the uploaded file and classify text data"""
67
  # Initialize result_df and validation_report
68
  result_df = None
69
  validation_report = None
70
-
71
  try:
72
  # Load data from file
73
  if isinstance(file, str):
74
  df = load_data(file)
75
  else:
76
  df = load_data(file.name)
77
-
78
  if not text_columns:
79
  return None, "Please select at least one text column"
80
-
81
  # Check if all selected columns exist
82
  missing_columns = [col for col in text_columns if col not in df.columns]
83
  if missing_columns:
84
- return None, f"Columns not found in the file: {', '.join(missing_columns)}. Available columns: {', '.join(df.columns)}"
85
-
 
 
 
86
  # Combine text from selected columns
87
  texts = []
88
  for _, row in df.iterrows():
89
  combined_text = " ".join(str(row[col]) for col in text_columns)
90
  texts.append(combined_text)
91
-
92
  # Parse categories if provided
93
  category_list = []
94
  if categories:
95
  category_list = [cat.strip() for cat in categories.split(",")]
96
-
97
  # Select classifier based on data size and user choice
98
  num_texts = len(texts)
99
-
100
  # If no specific model is chosen, select the most appropriate one
101
  if classifier_type == "auto":
102
  if num_texts <= 500:
@@ -107,30 +113,36 @@ def process_file(file, text_columns, categories, classifier_type, show_explanati
107
  classifier_type = "hybrid"
108
  else:
109
  classifier_type = "tfidf"
110
-
111
  # Initialize appropriate classifier
112
  if classifier_type == "tfidf":
113
  classifier = TFIDFClassifier()
114
  results = classifier.classify(texts, category_list)
115
  elif classifier_type in ["gpt35", "gpt4"]:
116
  if client is None:
117
- return None, "Erreur : Le client API n'est pas initialisé. Veuillez configurer une clé API valide dans l'onglet 'Setup'."
 
 
 
118
  model = "gpt-3.5-turbo" if classifier_type == "gpt35" else "gpt-4"
119
  classifier = LLMClassifier(client=client, model=model)
120
  results = classifier.classify(texts, category_list)
121
  else: # hybrid
122
  if client is None:
123
- return None, "Erreur : Le client API n'est pas initialisé. Veuillez configurer une clé API valide dans l'onglet 'Setup'."
 
 
 
124
  # First pass with TF-IDF
125
  tfidf_classifier = TFIDFClassifier()
126
  tfidf_results = tfidf_classifier.classify(texts, category_list)
127
-
128
  # Second pass with LLM for low confidence results
129
  llm_classifier = LLMClassifier(client=client, model="gpt-3.5-turbo")
130
  results = []
131
  low_confidence_texts = []
132
  low_confidence_indices = []
133
-
134
  for i, (text, tfidf_result) in enumerate(zip(texts, tfidf_results)):
135
  if tfidf_result["confidence"] < 70: # If confidence is below 70%
136
  low_confidence_texts.append(text)
@@ -138,91 +150,97 @@ def process_file(file, text_columns, categories, classifier_type, show_explanati
138
  results.append(None) # Placeholder
139
  else:
140
  results.append(tfidf_result)
141
-
142
  if low_confidence_texts:
143
- llm_results = llm_classifier.classify(low_confidence_texts, category_list)
 
 
144
  for idx, llm_result in zip(low_confidence_indices, llm_results):
145
  results[idx] = llm_result
146
-
147
  # Create results dataframe
148
  result_df = df.copy()
149
  result_df["Category"] = [r["category"] for r in results]
150
  result_df["Confidence"] = [r["confidence"] for r in results]
151
-
152
  if show_explanations:
153
  result_df["Explanation"] = [r["explanation"] for r in results]
154
-
155
  # Validate results using LLM
156
  validation_report = validate_results(result_df, text_columns, client)
157
-
158
  return result_df, validation_report
159
-
160
  except Exception as e:
161
  error_traceback = traceback.format_exc()
162
  return None, f"Error: {str(e)}\n{error_traceback}"
163
 
 
164
  def export_results(df, format_type):
165
  """Export results to a file and return the file path for download"""
166
  if df is None:
167
  return None
168
-
169
  # Create a temporary file
170
  import tempfile
171
  import os
172
-
173
  # Create a temporary directory if it doesn't exist
174
  temp_dir = "temp_exports"
175
  os.makedirs(temp_dir, exist_ok=True)
176
-
177
  # Generate a unique filename
178
  timestamp = time.strftime("%Y%m%d-%H%M%S")
179
  filename = f"classification_results_{timestamp}"
180
-
181
  if format_type == "excel":
182
  file_path = os.path.join(temp_dir, f"{filename}.xlsx")
183
  df.to_excel(file_path, index=False)
184
  else:
185
  file_path = os.path.join(temp_dir, f"{filename}.csv")
186
  df.to_csv(file_path, index=False)
187
-
188
  return file_path
189
 
 
190
  # Create Gradio interface
191
  with gr.Blocks(title="Text Classification System") as demo:
192
  gr.Markdown("# Text Classification System")
193
  gr.Markdown("Upload your data file (Excel/CSV) and classify text using AI")
194
-
195
  with gr.Tab("Setup"):
196
  api_key_input = gr.Textbox(
197
  label="OpenAI API Key",
198
  placeholder="Enter your API key here",
199
  type="password",
200
- value=OPENAI_API_KEY
201
  )
202
  api_key_button = gr.Button("Update API Key")
203
  api_key_message = gr.Textbox(label="Status", interactive=False)
204
-
205
  # Display current API status
206
- api_status = "API Key is set" if OPENAI_API_KEY else "No API Key found. Please set one."
 
 
207
  gr.Markdown(f"**Current API Status**: {api_status}")
208
-
209
- api_key_button.click(update_api_key, inputs=[api_key_input], outputs=[api_key_message])
210
-
 
 
211
  with gr.Tab("Classify Data"):
212
  with gr.Column():
213
  file_input = gr.File(label="Upload Excel/CSV File")
214
-
215
  # Variable to store available columns
216
  available_columns = gr.State([])
217
-
218
  # Button to load file and suggest categories
219
  load_categories_button = gr.Button("Load File")
220
-
221
  # Display original dataframe
222
  original_df = gr.Dataframe(
223
- label="Original Data",
224
- interactive=False,
225
- visible=False
226
  )
227
 
228
  with gr.Row():
@@ -232,31 +250,29 @@ with gr.Blocks(title="Text Classification System") as demo:
232
  choices=[],
233
  value=[],
234
  interactive=True,
235
- visible=False
236
  )
237
 
238
  new_category = gr.Textbox(
239
  label="Add New Category",
240
  placeholder="Enter a new category name",
241
- visible=False
242
  )
243
  with gr.Row():
244
  add_category_button = gr.Button("Add Category", visible=False)
245
- suggest_category_button = gr.Button("Suggest Category", visible=False)
246
-
 
247
 
248
  # Original categories input (hidden)
249
- categories = gr.Textbox(
250
- visible=False
251
- )
252
-
253
-
254
  with gr.Column():
255
  text_column = gr.CheckboxGroup(
256
- label="Select Text Columns",
257
- choices=[],
258
  interactive=True,
259
- visible=False
260
  )
261
 
262
  classifier_type = gr.Dropdown(
@@ -264,18 +280,20 @@ with gr.Blocks(title="Text Classification System") as demo:
264
  ("TF-IDF (Rapide, <1000 lignes)", "tfidf"),
265
  ("LLM GPT-3.5 (Fiable, <1000 lignes)", "gpt35"),
266
  ("LLM GPT-4 (Très fiable, <500 lignes)", "gpt4"),
267
- ("TF-IDF + LLM (Hybride, >1000 lignes)", "hybrid")
268
  ],
269
  label="Modèle de classification",
270
  value="gpt35",
271
- visible=False
 
 
 
272
  )
273
- show_explanations = gr.Checkbox(label="Show Explanations", value=True, visible=False)
274
-
275
  process_button = gr.Button("Process and Classify", visible=False)
276
 
277
  results_df = gr.Dataframe(interactive=True, visible=False)
278
-
279
  # Create containers for visualization and validation report
280
  with gr.Row(visible=False) as results_row:
281
  with gr.Column():
@@ -284,161 +302,251 @@ with gr.Blocks(title="Text Classification System") as demo:
284
  csv_download = gr.File(label="Download CSV", visible=False)
285
  excel_download = gr.File(label="Download Excel", visible=False)
286
  with gr.Column():
287
- validation_output = gr.Textbox(label="Validation Report", interactive=False)
288
- improve_button = gr.Button("Improve Classification with Report", visible=False)
 
 
 
 
289
 
290
  # Function to load file and suggest categories
291
  def load_file_and_suggest_categories(file):
292
  if not file:
293
- return [], gr.CheckboxGroup(choices=[]), gr.CheckboxGroup(choices=[], visible=False), gr.Textbox(visible=False), gr.Button(visible=False), gr.Button(visible=False), gr.CheckboxGroup(choices=[], visible=False), gr.Dropdown(visible=False), gr.Checkbox(visible=False), gr.Button(visible=False), gr.Dataframe(visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
294
  try:
295
  df = load_data(file.name)
296
  columns = list(df.columns)
297
-
298
  # Analyze columns to suggest text columns
299
  suggested_text_columns = []
300
  for col in columns:
301
  # Check if column contains text data
302
- if df[col].dtype == 'object': # String type
303
  # Check if column contains mostly text (not just numbers or dates)
304
  sample = df[col].head(100).dropna()
305
  if len(sample) > 0:
306
  # Check if most values contain spaces (indicating text)
307
- text_ratio = sum(' ' in str(val) for val in sample) / len(sample)
308
- if text_ratio > 0.3: # If more than 30% of values contain spaces
 
 
 
 
309
  suggested_text_columns.append(col)
310
-
311
  # If no columns were suggested, use all object columns
312
  if not suggested_text_columns:
313
- suggested_text_columns = [col for col in columns if df[col].dtype == 'object']
314
-
 
 
315
  # Get a sample of text for category suggestion
316
  sample_texts = []
317
  for col in suggested_text_columns:
318
  sample_texts.extend(df[col].head(5).tolist())
319
-
320
  # Use LLM to suggest categories
321
  if client:
322
- prompt = CATEGORY_SUGGESTION_PROMPT.format("\n---\n".join(sample_texts[:5]))
 
 
323
  try:
324
  response = client.chat.completions.create(
325
  model="gpt-3.5-turbo",
326
  messages=[{"role": "user", "content": prompt}],
327
  temperature=0,
328
- max_tokens=100
329
  )
330
- suggested_cats = [cat.strip() for cat in response.choices[0].message.content.strip().split(",")]
 
 
 
 
 
331
  except:
332
- suggested_cats = ["Positive", "Negative", "Neutral", "Mixed", "Other"]
 
 
 
 
 
 
333
  else:
334
- suggested_cats = ["Positive", "Negative", "Neutral", "Mixed", "Other"]
335
-
 
 
 
 
 
 
336
  return (
337
- columns,
338
- gr.CheckboxGroup(choices=columns, value=suggested_text_columns),
339
- gr.CheckboxGroup(choices=suggested_cats, value=suggested_cats, visible=True),
 
 
340
  gr.Textbox(visible=True),
341
  gr.Button(visible=True),
342
  gr.Button(visible=True),
343
- gr.CheckboxGroup(choices=columns, value=suggested_text_columns, visible=True),
 
 
344
  gr.Dropdown(visible=True),
345
  gr.Checkbox(visible=True),
346
  gr.Button(visible=True),
347
- gr.Dataframe(value=df, visible=True)
348
  )
349
  except Exception as e:
350
- return [], gr.CheckboxGroup(choices=[]), gr.CheckboxGroup(choices=[], visible=False), gr.Textbox(visible=False), gr.Button(visible=False), gr.Button(visible=False), gr.CheckboxGroup(choices=[], visible=False), gr.Dropdown(visible=False), gr.Checkbox(visible=False), gr.Button(visible=False), gr.Dataframe(visible=False)
351
-
 
 
 
 
 
 
 
 
 
 
 
 
352
  # Function to add a new category
353
  def add_new_category(current_categories, new_category):
354
  if not new_category or new_category.strip() == "":
355
  return current_categories
356
  new_categories = current_categories + [new_category.strip()]
357
  return gr.CheckboxGroup(choices=new_categories, value=new_categories)
358
-
359
  # Function to update categories textbox
360
  def update_categories_textbox(selected_categories):
361
  return ", ".join(selected_categories)
362
-
363
  # Function to show results after processing
364
  def show_results(df, validation_report):
365
  """Show the results after processing"""
366
  if df is None:
367
- return gr.Row(visible=False), gr.File(visible=False), gr.File(visible=False), gr.Dataframe(visible=False)
368
-
 
 
 
 
 
369
  # Export to both formats
370
  csv_path = export_results(df, "csv")
371
  excel_path = export_results(df, "excel")
372
-
373
- return gr.Row(visible=True), gr.File(value=csv_path, visible=True), gr.File(value=excel_path, visible=True), gr.Dataframe(value=df, visible=True)
374
-
 
 
 
 
 
375
  # Function to suggest a new category
376
  def suggest_new_category(file, current_categories, text_columns):
377
  if not file or not text_columns:
378
- return gr.CheckboxGroup(choices=current_categories, value=current_categories)
379
-
 
 
380
  try:
381
  df = load_data(file.name)
382
-
383
  # Get sample texts from selected columns
384
  sample_texts = []
385
  for col in text_columns:
386
  sample_texts.extend(df[col].head(5).tolist())
387
-
388
  if client:
389
  prompt = ADDITIONAL_CATEGORY_PROMPT.format(
390
  existing_categories=", ".join(current_categories),
391
- sample_texts="\n---\n".join(sample_texts[:10])
392
  )
393
  try:
394
  response = client.chat.completions.create(
395
  model="gpt-3.5-turbo",
396
  messages=[{"role": "user", "content": prompt}],
397
  temperature=0,
398
- max_tokens=50
399
  )
400
  new_cat = response.choices[0].message.content.strip()
401
  if new_cat and new_cat not in current_categories:
402
  current_categories.append(new_cat)
403
  except:
404
  pass
405
-
406
- return gr.CheckboxGroup(choices=current_categories, value=current_categories)
 
 
407
  except Exception as e:
408
- return gr.CheckboxGroup(choices=current_categories, value=current_categories)
409
-
 
 
410
  # Function to handle export and show download button
411
  def handle_export(df, format_type):
412
  if df is None:
413
  return gr.File(visible=False)
414
  file_path = export_results(df, format_type)
415
  return gr.File(value=file_path, visible=True)
416
-
417
  # Function to improve classification based on validation report
418
- def improve_classification(df, validation_report, text_columns, categories, classifier_type, show_explanations, file):
 
 
 
 
 
 
 
 
419
  """Improve classification based on validation report"""
420
  if df is None or not validation_report:
421
- return df, validation_report, gr.Button(visible=False), gr.CheckboxGroup(choices=[], value=[])
422
-
 
 
 
 
 
423
  try:
424
  # Extract insights from validation report
425
  if client:
426
  prompt = VALIDATION_ANALYSIS_PROMPT.format(
427
  validation_report=validation_report,
428
- current_categories=categories
429
  )
430
  try:
431
  response = client.chat.completions.create(
432
  model="gpt-4",
433
  messages=[{"role": "user", "content": prompt}],
434
  temperature=0,
435
- max_tokens=300
 
 
 
436
  )
437
- improvements = json.loads(response.choices[0].message.content.strip())
438
-
439
  # Get current categories
440
- current_categories = [cat.strip() for cat in categories.split(",")]
441
-
 
 
442
  # If new categories are needed, suggest them based on the data
443
  if improvements.get("new_categories_needed", False):
444
  # Get sample texts for category suggestion
@@ -449,51 +557,84 @@ with gr.Blocks(title="Text Classification System") as demo:
449
  else:
450
  temp_df = load_data(file.name)
451
  sample_texts.extend(temp_df[col].head(10).tolist())
452
-
453
  category_prompt = CATEGORY_IMPROVEMENT_PROMPT.format(
454
  current_categories=", ".join(current_categories),
455
- analysis=improvements.get('analysis', ''),
456
- sample_texts="\n---\n".join(sample_texts[:10])
457
  )
458
-
459
  category_response = client.chat.completions.create(
460
  model="gpt-4",
461
  messages=[{"role": "user", "content": category_prompt}],
462
  temperature=0,
463
- max_tokens=100
464
  )
465
-
466
- new_categories = [cat.strip() for cat in category_response.choices[0].message.content.strip().split(",")]
 
 
 
 
 
467
  # Combine current and new categories
468
  all_categories = current_categories + new_categories
469
  categories = ",".join(all_categories)
470
-
471
  # Process with improved parameters
472
  improved_df, new_validation = process_file(
473
  file,
474
  text_columns,
475
  categories,
476
  classifier_type,
477
- show_explanations
 
 
 
 
 
 
 
 
 
478
  )
479
-
480
- return improved_df, new_validation, gr.Button(visible=True), gr.CheckboxGroup(choices=all_categories, value=all_categories)
481
  except Exception as e:
482
  print(f"Error in improvement process: {str(e)}")
483
- return df, validation_report, gr.Button(visible=True), gr.CheckboxGroup(choices=current_categories, value=current_categories)
 
 
 
 
 
 
 
484
  else:
485
- return df, validation_report, gr.Button(visible=True), gr.CheckboxGroup(choices=current_categories, value=current_categories)
 
 
 
 
 
 
 
486
  except Exception as e:
487
  print(f"Error in improvement process: {str(e)}")
488
- return df, validation_report, gr.Button(visible=True), gr.CheckboxGroup(choices=current_categories, value=current_categories)
489
-
 
 
 
 
 
 
 
490
  # Connect functions
491
  load_categories_button.click(
492
  load_file_and_suggest_categories,
493
  inputs=[file_input],
494
  outputs=[
495
- available_columns,
496
- text_column,
497
  suggested_categories,
498
  new_category,
499
  add_category_button,
@@ -502,74 +643,88 @@ with gr.Blocks(title="Text Classification System") as demo:
502
  classifier_type,
503
  show_explanations,
504
  process_button,
505
- original_df
506
- ]
507
  )
508
-
509
  add_category_button.click(
510
  add_new_category,
511
  inputs=[suggested_categories, new_category],
512
- outputs=[suggested_categories]
513
  )
514
-
515
  suggested_categories.change(
516
  update_categories_textbox,
517
  inputs=[suggested_categories],
518
- outputs=[categories]
519
  )
520
-
521
  suggest_category_button.click(
522
  suggest_new_category,
523
  inputs=[file_input, suggested_categories, text_column],
524
- outputs=[suggested_categories]
525
  )
526
-
527
  process_button.click(
528
- lambda: gr.Dataframe(visible=True),
529
- inputs=[],
530
- outputs=[results_df]
531
  ).then(
532
  process_file,
533
- inputs=[file_input, text_column, categories, classifier_type, show_explanations],
534
- outputs=[results_df, validation_output]
 
 
 
 
 
 
535
  ).then(
536
  show_results,
537
  inputs=[results_df, validation_output],
538
- outputs=[results_row, csv_download, excel_download, results_df]
539
  ).then(
540
- visualize_results,
541
- inputs=[results_df, text_column],
542
- outputs=[visualization]
543
  ).then(
544
- lambda x: gr.Button(visible=True),
545
- inputs=[],
546
- outputs=[improve_button]
547
  )
548
-
549
  improve_button.click(
550
  improve_classification,
551
- inputs=[results_df, validation_output, text_column, categories, classifier_type, show_explanations, file_input],
552
- outputs=[results_df, validation_output, improve_button, suggested_categories]
 
 
 
 
 
 
 
 
 
 
 
 
 
553
  ).then(
554
  show_results,
555
  inputs=[results_df, validation_output],
556
- outputs=[results_row, csv_download, excel_download, results_df]
557
  ).then(
558
- visualize_results,
559
- inputs=[results_df, text_column],
560
- outputs=[visualization]
561
  )
562
 
 
563
  def create_example_data():
564
  """Create example data for demonstration"""
565
  from utils import create_example_file
 
566
  example_path = create_example_file()
567
  return f"Example file created at: {example_path}"
568
 
 
569
  if __name__ == "__main__":
570
  # Create examples directory and sample file if it doesn't exist
571
  if not os.path.exists("examples"):
572
  create_example_data()
573
-
574
  # Launch the Gradio app
575
  demo.launch()
 
20
  CATEGORY_SUGGESTION_PROMPT,
21
  ADDITIONAL_CATEGORY_PROMPT,
22
  VALIDATION_ANALYSIS_PROMPT,
23
+ CATEGORY_IMPROVEMENT_PROMPT,
24
  )
25
 
26
  # Configure logging
27
+ logging.basicConfig(
28
+ level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
29
+ )
30
 
31
  # Initialize API key from environment variable
32
  OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
 
40
  except Exception as e:
41
  logging.error(f"Failed to initialize OpenAI client: {str(e)}")
42
 
43
+
44
  def update_api_key(api_key):
45
  """Update the OpenAI API key"""
46
  global OPENAI_API_KEY, client
47
+
48
  if not api_key:
49
  return "API Key cannot be empty"
50
+
51
  OPENAI_API_KEY = api_key
52
+
53
  try:
54
  client = OpenAI(api_key=api_key)
55
  # Test the connection with a simple request
56
  response = client.chat.completions.create(
57
  model="gpt-3.5-turbo",
58
  messages=[{"role": "user", "content": "test"}],
59
+ max_tokens=5,
60
  )
61
  return f"API Key updated and verified successfully"
62
  except Exception as e:
 
64
  logging.error(f"API key update failed: {error_msg}")
65
  return f"Failed to update API Key: {error_msg}"
66
 
67
+
68
  def process_file(file, text_columns, categories, classifier_type, show_explanations):
69
  """Process the uploaded file and classify text data"""
70
  # Initialize result_df and validation_report
71
  result_df = None
72
  validation_report = None
73
+
74
  try:
75
  # Load data from file
76
  if isinstance(file, str):
77
  df = load_data(file)
78
  else:
79
  df = load_data(file.name)
80
+
81
  if not text_columns:
82
  return None, "Please select at least one text column"
83
+
84
  # Check if all selected columns exist
85
  missing_columns = [col for col in text_columns if col not in df.columns]
86
  if missing_columns:
87
+ return (
88
+ None,
89
+ f"Columns not found in the file: {', '.join(missing_columns)}. Available columns: {', '.join(df.columns)}",
90
+ )
91
+
92
  # Combine text from selected columns
93
  texts = []
94
  for _, row in df.iterrows():
95
  combined_text = " ".join(str(row[col]) for col in text_columns)
96
  texts.append(combined_text)
97
+
98
  # Parse categories if provided
99
  category_list = []
100
  if categories:
101
  category_list = [cat.strip() for cat in categories.split(",")]
102
+
103
  # Select classifier based on data size and user choice
104
  num_texts = len(texts)
105
+
106
  # If no specific model is chosen, select the most appropriate one
107
  if classifier_type == "auto":
108
  if num_texts <= 500:
 
113
  classifier_type = "hybrid"
114
  else:
115
  classifier_type = "tfidf"
116
+
117
  # Initialize appropriate classifier
118
  if classifier_type == "tfidf":
119
  classifier = TFIDFClassifier()
120
  results = classifier.classify(texts, category_list)
121
  elif classifier_type in ["gpt35", "gpt4"]:
122
  if client is None:
123
+ return (
124
+ None,
125
+ "Erreur : Le client API n'est pas initialisé. Veuillez configurer une clé API valide dans l'onglet 'Setup'.",
126
+ )
127
  model = "gpt-3.5-turbo" if classifier_type == "gpt35" else "gpt-4"
128
  classifier = LLMClassifier(client=client, model=model)
129
  results = classifier.classify(texts, category_list)
130
  else: # hybrid
131
  if client is None:
132
+ return (
133
+ None,
134
+ "Erreur : Le client API n'est pas initialisé. Veuillez configurer une clé API valide dans l'onglet 'Setup'.",
135
+ )
136
  # First pass with TF-IDF
137
  tfidf_classifier = TFIDFClassifier()
138
  tfidf_results = tfidf_classifier.classify(texts, category_list)
139
+
140
  # Second pass with LLM for low confidence results
141
  llm_classifier = LLMClassifier(client=client, model="gpt-3.5-turbo")
142
  results = []
143
  low_confidence_texts = []
144
  low_confidence_indices = []
145
+
146
  for i, (text, tfidf_result) in enumerate(zip(texts, tfidf_results)):
147
  if tfidf_result["confidence"] < 70: # If confidence is below 70%
148
  low_confidence_texts.append(text)
 
150
  results.append(None) # Placeholder
151
  else:
152
  results.append(tfidf_result)
153
+
154
  if low_confidence_texts:
155
+ llm_results = llm_classifier.classify(
156
+ low_confidence_texts, category_list
157
+ )
158
  for idx, llm_result in zip(low_confidence_indices, llm_results):
159
  results[idx] = llm_result
160
+
161
  # Create results dataframe
162
  result_df = df.copy()
163
  result_df["Category"] = [r["category"] for r in results]
164
  result_df["Confidence"] = [r["confidence"] for r in results]
165
+
166
  if show_explanations:
167
  result_df["Explanation"] = [r["explanation"] for r in results]
168
+
169
  # Validate results using LLM
170
  validation_report = validate_results(result_df, text_columns, client)
171
+
172
  return result_df, validation_report
173
+
174
  except Exception as e:
175
  error_traceback = traceback.format_exc()
176
  return None, f"Error: {str(e)}\n{error_traceback}"
177
 
178
+
179
  def export_results(df, format_type):
180
  """Export results to a file and return the file path for download"""
181
  if df is None:
182
  return None
183
+
184
  # Create a temporary file
185
  import tempfile
186
  import os
187
+
188
  # Create a temporary directory if it doesn't exist
189
  temp_dir = "temp_exports"
190
  os.makedirs(temp_dir, exist_ok=True)
191
+
192
  # Generate a unique filename
193
  timestamp = time.strftime("%Y%m%d-%H%M%S")
194
  filename = f"classification_results_{timestamp}"
195
+
196
  if format_type == "excel":
197
  file_path = os.path.join(temp_dir, f"{filename}.xlsx")
198
  df.to_excel(file_path, index=False)
199
  else:
200
  file_path = os.path.join(temp_dir, f"{filename}.csv")
201
  df.to_csv(file_path, index=False)
202
+
203
  return file_path
204
 
205
+
206
  # Create Gradio interface
207
  with gr.Blocks(title="Text Classification System") as demo:
208
  gr.Markdown("# Text Classification System")
209
  gr.Markdown("Upload your data file (Excel/CSV) and classify text using AI")
210
+
211
  with gr.Tab("Setup"):
212
  api_key_input = gr.Textbox(
213
  label="OpenAI API Key",
214
  placeholder="Enter your API key here",
215
  type="password",
216
+ value=OPENAI_API_KEY,
217
  )
218
  api_key_button = gr.Button("Update API Key")
219
  api_key_message = gr.Textbox(label="Status", interactive=False)
220
+
221
  # Display current API status
222
+ api_status = (
223
+ "API Key is set" if OPENAI_API_KEY else "No API Key found. Please set one."
224
+ )
225
  gr.Markdown(f"**Current API Status**: {api_status}")
226
+
227
+ api_key_button.click(
228
+ update_api_key, inputs=[api_key_input], outputs=[api_key_message]
229
+ )
230
+
231
  with gr.Tab("Classify Data"):
232
  with gr.Column():
233
  file_input = gr.File(label="Upload Excel/CSV File")
234
+
235
  # Variable to store available columns
236
  available_columns = gr.State([])
237
+
238
  # Button to load file and suggest categories
239
  load_categories_button = gr.Button("Load File")
240
+
241
  # Display original dataframe
242
  original_df = gr.Dataframe(
243
+ label="Original Data", interactive=False, visible=False
 
 
244
  )
245
 
246
  with gr.Row():
 
250
  choices=[],
251
  value=[],
252
  interactive=True,
253
+ visible=False,
254
  )
255
 
256
  new_category = gr.Textbox(
257
  label="Add New Category",
258
  placeholder="Enter a new category name",
259
+ visible=False,
260
  )
261
  with gr.Row():
262
  add_category_button = gr.Button("Add Category", visible=False)
263
+ suggest_category_button = gr.Button(
264
+ "Suggest Category", visible=False
265
+ )
266
 
267
  # Original categories input (hidden)
268
+ categories = gr.Textbox(visible=False)
269
+
 
 
 
270
  with gr.Column():
271
  text_column = gr.CheckboxGroup(
272
+ label="Select Text Columns",
273
+ choices=[],
274
  interactive=True,
275
+ visible=False,
276
  )
277
 
278
  classifier_type = gr.Dropdown(
 
280
  ("TF-IDF (Rapide, <1000 lignes)", "tfidf"),
281
  ("LLM GPT-3.5 (Fiable, <1000 lignes)", "gpt35"),
282
  ("LLM GPT-4 (Très fiable, <500 lignes)", "gpt4"),
283
+ ("TF-IDF + LLM (Hybride, >1000 lignes)", "hybrid"),
284
  ],
285
  label="Modèle de classification",
286
  value="gpt35",
287
+ visible=False,
288
+ )
289
+ show_explanations = gr.Checkbox(
290
+ label="Show Explanations", value=True, visible=False
291
  )
292
+
 
293
  process_button = gr.Button("Process and Classify", visible=False)
294
 
295
  results_df = gr.Dataframe(interactive=True, visible=False)
296
+
297
  # Create containers for visualization and validation report
298
  with gr.Row(visible=False) as results_row:
299
  with gr.Column():
 
302
  csv_download = gr.File(label="Download CSV", visible=False)
303
  excel_download = gr.File(label="Download Excel", visible=False)
304
  with gr.Column():
305
+ validation_output = gr.Textbox(
306
+ label="Validation Report", interactive=False
307
+ )
308
+ improve_button = gr.Button(
309
+ "Improve Classification with Report", visible=False
310
+ )
311
 
312
  # Function to load file and suggest categories
313
  def load_file_and_suggest_categories(file):
314
  if not file:
315
+ return (
316
+ [],
317
+ gr.CheckboxGroup(choices=[]),
318
+ gr.CheckboxGroup(choices=[], visible=False),
319
+ gr.Textbox(visible=False),
320
+ gr.Button(visible=False),
321
+ gr.Button(visible=False),
322
+ gr.CheckboxGroup(choices=[], visible=False),
323
+ gr.Dropdown(visible=False),
324
+ gr.Checkbox(visible=False),
325
+ gr.Button(visible=False),
326
+ gr.Dataframe(visible=False),
327
+ )
328
  try:
329
  df = load_data(file.name)
330
  columns = list(df.columns)
331
+
332
  # Analyze columns to suggest text columns
333
  suggested_text_columns = []
334
  for col in columns:
335
  # Check if column contains text data
336
+ if df[col].dtype == "object": # String type
337
  # Check if column contains mostly text (not just numbers or dates)
338
  sample = df[col].head(100).dropna()
339
  if len(sample) > 0:
340
  # Check if most values contain spaces (indicating text)
341
+ text_ratio = sum(" " in str(val) for val in sample) / len(
342
+ sample
343
+ )
344
+ if (
345
+ text_ratio > 0.3
346
+ ): # If more than 30% of values contain spaces
347
  suggested_text_columns.append(col)
348
+
349
  # If no columns were suggested, use all object columns
350
  if not suggested_text_columns:
351
+ suggested_text_columns = [
352
+ col for col in columns if df[col].dtype == "object"
353
+ ]
354
+
355
  # Get a sample of text for category suggestion
356
  sample_texts = []
357
  for col in suggested_text_columns:
358
  sample_texts.extend(df[col].head(5).tolist())
359
+
360
  # Use LLM to suggest categories
361
  if client:
362
+ prompt = CATEGORY_SUGGESTION_PROMPT.format(
363
+ "\n---\n".join(sample_texts[:5])
364
+ )
365
  try:
366
  response = client.chat.completions.create(
367
  model="gpt-3.5-turbo",
368
  messages=[{"role": "user", "content": prompt}],
369
  temperature=0,
370
+ max_tokens=100,
371
  )
372
+ suggested_cats = [
373
+ cat.strip()
374
+ for cat in response.choices[0]
375
+ .message.content.strip()
376
+ .split(",")
377
+ ]
378
  except:
379
+ suggested_cats = [
380
+ "Positive",
381
+ "Negative",
382
+ "Neutral",
383
+ "Mixed",
384
+ "Other",
385
+ ]
386
  else:
387
+ suggested_cats = [
388
+ "Positive",
389
+ "Negative",
390
+ "Neutral",
391
+ "Mixed",
392
+ "Other",
393
+ ]
394
+
395
  return (
396
+ columns,
397
+ gr.CheckboxGroup(choices=columns, value=suggested_text_columns),
398
+ gr.CheckboxGroup(
399
+ choices=suggested_cats, value=suggested_cats, visible=True
400
+ ),
401
  gr.Textbox(visible=True),
402
  gr.Button(visible=True),
403
  gr.Button(visible=True),
404
+ gr.CheckboxGroup(
405
+ choices=columns, value=suggested_text_columns, visible=True
406
+ ),
407
  gr.Dropdown(visible=True),
408
  gr.Checkbox(visible=True),
409
  gr.Button(visible=True),
410
+ gr.Dataframe(value=df, visible=True),
411
  )
412
  except Exception as e:
413
+ return (
414
+ [],
415
+ gr.CheckboxGroup(choices=[]),
416
+ gr.CheckboxGroup(choices=[], visible=False),
417
+ gr.Textbox(visible=False),
418
+ gr.Button(visible=False),
419
+ gr.Button(visible=False),
420
+ gr.CheckboxGroup(choices=[], visible=False),
421
+ gr.Dropdown(visible=False),
422
+ gr.Checkbox(visible=False),
423
+ gr.Button(visible=False),
424
+ gr.Dataframe(visible=False),
425
+ )
426
+
427
  # Function to add a new category
428
  def add_new_category(current_categories, new_category):
429
  if not new_category or new_category.strip() == "":
430
  return current_categories
431
  new_categories = current_categories + [new_category.strip()]
432
  return gr.CheckboxGroup(choices=new_categories, value=new_categories)
433
+
434
  # Function to update categories textbox
435
  def update_categories_textbox(selected_categories):
436
  return ", ".join(selected_categories)
437
+
438
  # Function to show results after processing
439
  def show_results(df, validation_report):
440
  """Show the results after processing"""
441
  if df is None:
442
+ return (
443
+ gr.Row(visible=False),
444
+ gr.File(visible=False),
445
+ gr.File(visible=False),
446
+ gr.Dataframe(visible=False),
447
+ )
448
+
449
  # Export to both formats
450
  csv_path = export_results(df, "csv")
451
  excel_path = export_results(df, "excel")
452
+
453
+ return (
454
+ gr.Row(visible=True),
455
+ gr.File(value=csv_path, visible=True),
456
+ gr.File(value=excel_path, visible=True),
457
+ gr.Dataframe(value=df, visible=True),
458
+ )
459
+
460
  # Function to suggest a new category
461
  def suggest_new_category(file, current_categories, text_columns):
462
  if not file or not text_columns:
463
+ return gr.CheckboxGroup(
464
+ choices=current_categories, value=current_categories
465
+ )
466
+
467
  try:
468
  df = load_data(file.name)
469
+
470
  # Get sample texts from selected columns
471
  sample_texts = []
472
  for col in text_columns:
473
  sample_texts.extend(df[col].head(5).tolist())
474
+
475
  if client:
476
  prompt = ADDITIONAL_CATEGORY_PROMPT.format(
477
  existing_categories=", ".join(current_categories),
478
+ sample_texts="\n---\n".join(sample_texts[:10]),
479
  )
480
  try:
481
  response = client.chat.completions.create(
482
  model="gpt-3.5-turbo",
483
  messages=[{"role": "user", "content": prompt}],
484
  temperature=0,
485
+ max_tokens=50,
486
  )
487
  new_cat = response.choices[0].message.content.strip()
488
  if new_cat and new_cat not in current_categories:
489
  current_categories.append(new_cat)
490
  except:
491
  pass
492
+
493
+ return gr.CheckboxGroup(
494
+ choices=current_categories, value=current_categories
495
+ )
496
  except Exception as e:
497
+ return gr.CheckboxGroup(
498
+ choices=current_categories, value=current_categories
499
+ )
500
+
501
  # Function to handle export and show download button
502
  def handle_export(df, format_type):
503
  if df is None:
504
  return gr.File(visible=False)
505
  file_path = export_results(df, format_type)
506
  return gr.File(value=file_path, visible=True)
507
+
508
  # Function to improve classification based on validation report
509
+ def improve_classification(
510
+ df,
511
+ validation_report,
512
+ text_columns,
513
+ categories,
514
+ classifier_type,
515
+ show_explanations,
516
+ file,
517
+ ):
518
  """Improve classification based on validation report"""
519
  if df is None or not validation_report:
520
+ return (
521
+ df,
522
+ validation_report,
523
+ gr.Button(visible=False),
524
+ gr.CheckboxGroup(choices=[], value=[]),
525
+ )
526
+
527
  try:
528
  # Extract insights from validation report
529
  if client:
530
  prompt = VALIDATION_ANALYSIS_PROMPT.format(
531
  validation_report=validation_report,
532
+ current_categories=categories,
533
  )
534
  try:
535
  response = client.chat.completions.create(
536
  model="gpt-4",
537
  messages=[{"role": "user", "content": prompt}],
538
  temperature=0,
539
+ max_tokens=300,
540
+ )
541
+ improvements = json.loads(
542
+ response.choices[0].message.content.strip()
543
  )
544
+
 
545
  # Get current categories
546
+ current_categories = [
547
+ cat.strip() for cat in categories.split(",")
548
+ ]
549
+
550
  # If new categories are needed, suggest them based on the data
551
  if improvements.get("new_categories_needed", False):
552
  # Get sample texts for category suggestion
 
557
  else:
558
  temp_df = load_data(file.name)
559
  sample_texts.extend(temp_df[col].head(10).tolist())
560
+
561
  category_prompt = CATEGORY_IMPROVEMENT_PROMPT.format(
562
  current_categories=", ".join(current_categories),
563
+ analysis=improvements.get("analysis", ""),
564
+ sample_texts="\n---\n".join(sample_texts[:10]),
565
  )
566
+
567
  category_response = client.chat.completions.create(
568
  model="gpt-4",
569
  messages=[{"role": "user", "content": category_prompt}],
570
  temperature=0,
571
+ max_tokens=100,
572
  )
573
+
574
+ new_categories = [
575
+ cat.strip()
576
+ for cat in category_response.choices[0]
577
+ .message.content.strip()
578
+ .split(",")
579
+ ]
580
  # Combine current and new categories
581
  all_categories = current_categories + new_categories
582
  categories = ",".join(all_categories)
583
+
584
  # Process with improved parameters
585
  improved_df, new_validation = process_file(
586
  file,
587
  text_columns,
588
  categories,
589
  classifier_type,
590
+ show_explanations,
591
+ )
592
+
593
+ return (
594
+ improved_df,
595
+ new_validation,
596
+ gr.Button(visible=True),
597
+ gr.CheckboxGroup(
598
+ choices=all_categories, value=all_categories
599
+ ),
600
  )
 
 
601
  except Exception as e:
602
  print(f"Error in improvement process: {str(e)}")
603
+ return (
604
+ df,
605
+ validation_report,
606
+ gr.Button(visible=True),
607
+ gr.CheckboxGroup(
608
+ choices=current_categories, value=current_categories
609
+ ),
610
+ )
611
  else:
612
+ return (
613
+ df,
614
+ validation_report,
615
+ gr.Button(visible=True),
616
+ gr.CheckboxGroup(
617
+ choices=current_categories, value=current_categories
618
+ ),
619
+ )
620
  except Exception as e:
621
  print(f"Error in improvement process: {str(e)}")
622
+ return (
623
+ df,
624
+ validation_report,
625
+ gr.Button(visible=True),
626
+ gr.CheckboxGroup(
627
+ choices=current_categories, value=current_categories
628
+ ),
629
+ )
630
+
631
  # Connect functions
632
  load_categories_button.click(
633
  load_file_and_suggest_categories,
634
  inputs=[file_input],
635
  outputs=[
636
+ available_columns,
637
+ text_column,
638
  suggested_categories,
639
  new_category,
640
  add_category_button,
 
643
  classifier_type,
644
  show_explanations,
645
  process_button,
646
+ original_df,
647
+ ],
648
  )
649
+
650
  add_category_button.click(
651
  add_new_category,
652
  inputs=[suggested_categories, new_category],
653
+ outputs=[suggested_categories],
654
  )
655
+
656
  suggested_categories.change(
657
  update_categories_textbox,
658
  inputs=[suggested_categories],
659
+ outputs=[categories],
660
  )
661
+
662
  suggest_category_button.click(
663
  suggest_new_category,
664
  inputs=[file_input, suggested_categories, text_column],
665
+ outputs=[suggested_categories],
666
  )
667
+
668
  process_button.click(
669
+ lambda: gr.Dataframe(visible=True), inputs=[], outputs=[results_df]
 
 
670
  ).then(
671
  process_file,
672
+ inputs=[
673
+ file_input,
674
+ text_column,
675
+ categories,
676
+ classifier_type,
677
+ show_explanations,
678
+ ],
679
+ outputs=[results_df, validation_output],
680
  ).then(
681
  show_results,
682
  inputs=[results_df, validation_output],
683
+ outputs=[results_row, csv_download, excel_download, results_df],
684
  ).then(
685
+ visualize_results, inputs=[results_df, text_column], outputs=[visualization]
 
 
686
  ).then(
687
+ lambda x: gr.Button(visible=True), inputs=[], outputs=[improve_button]
 
 
688
  )
689
+
690
  improve_button.click(
691
  improve_classification,
692
+ inputs=[
693
+ results_df,
694
+ validation_output,
695
+ text_column,
696
+ categories,
697
+ classifier_type,
698
+ show_explanations,
699
+ file_input,
700
+ ],
701
+ outputs=[
702
+ results_df,
703
+ validation_output,
704
+ improve_button,
705
+ suggested_categories,
706
+ ],
707
  ).then(
708
  show_results,
709
  inputs=[results_df, validation_output],
710
+ outputs=[results_row, csv_download, excel_download, results_df],
711
  ).then(
712
+ visualize_results, inputs=[results_df, text_column], outputs=[visualization]
 
 
713
  )
714
 
715
+
716
  def create_example_data():
717
  """Create example data for demonstration"""
718
  from utils import create_example_file
719
+
720
  example_path = create_example_file()
721
  return f"Example file created at: {example_path}"
722
 
723
+
724
  if __name__ == "__main__":
725
  # Create examples directory and sample file if it doesn't exist
726
  if not os.path.exists("examples"):
727
  create_example_data()
728
+
729
  # Launch the Gradio app
730
  demo.launch()
classifiers.py CHANGED
@@ -9,32 +9,34 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
9
  from typing import List, Dict, Any, Optional
10
  from prompts import CATEGORY_SUGGESTION_PROMPT, TEXT_CLASSIFICATION_PROMPT
11
 
 
12
  class BaseClassifier:
13
  """Base class for text classifiers"""
 
14
  def __init__(self):
15
  pass
16
-
17
  def classify(self, texts, categories=None):
18
  """
19
  Classify a list of texts into categories
20
-
21
  Args:
22
  texts (list): List of text strings to classify
23
  categories (list, optional): List of category names. If None, categories will be auto-detected
24
-
25
  Returns:
26
  list: List of classification results with categories, confidence scores, and explanations
27
  """
28
  raise NotImplementedError("Subclasses must implement this method")
29
-
30
  def _generate_default_categories(self, texts, num_clusters=5):
31
  """
32
  Generate default categories based on text clustering
33
-
34
  Args:
35
  texts (list): List of text strings
36
  num_clusters (int): Number of clusters to generate
37
-
38
  Returns:
39
  list: List of category names
40
  """
@@ -45,25 +47,23 @@ class BaseClassifier:
45
 
46
  class TFIDFClassifier(BaseClassifier):
47
  """Classifier using TF-IDF and clustering for fast classification"""
48
-
49
  def __init__(self):
50
  super().__init__()
51
  self.vectorizer = TfidfVectorizer(
52
- max_features=1000,
53
- stop_words='english',
54
- ngram_range=(1, 2)
55
  )
56
  self.model = None
57
  self.feature_names = None
58
  self.categories = None
59
  self.centroids = None
60
-
61
  def classify(self, texts, categories=None):
62
  """Classify texts using TF-IDF and clustering"""
63
  # Vectorize the texts
64
  X = self.vectorizer.fit_transform(texts)
65
  self.feature_names = self.vectorizer.get_feature_names_out()
66
-
67
  # Auto-detect categories if not provided
68
  if not categories:
69
  num_clusters = min(5, len(texts)) # Don't create more clusters than texts
@@ -71,98 +71,106 @@ class TFIDFClassifier(BaseClassifier):
71
  else:
72
  self.categories = categories
73
  num_clusters = len(categories)
74
-
75
  # Cluster the texts
76
  self.model = KMeans(n_clusters=num_clusters, random_state=42)
77
  clusters = self.model.fit_predict(X)
78
  self.centroids = self.model.cluster_centers_
79
-
80
  # Calculate distances to centroids for confidence
81
  distances = self._calculate_distances(X)
82
-
83
  # Prepare results
84
  results = []
85
  for i, text in enumerate(texts):
86
  cluster_idx = clusters[i]
87
-
88
  # Calculate confidence (inverse of distance, normalized)
89
  confidence = self._calculate_confidence(distances[i])
90
-
91
  # Create explanation
92
  explanation = self._generate_explanation(X[i], cluster_idx)
93
-
94
- results.append({
95
- "category": self.categories[cluster_idx],
96
- "confidence": confidence,
97
- "explanation": explanation
98
- })
99
-
 
 
100
  return results
101
-
102
  def _calculate_distances(self, X):
103
  """Calculate distances from each point to each centroid"""
104
- return np.sqrt(((X.toarray()[:, np.newaxis, :] - self.centroids[np.newaxis, :, :]) ** 2).sum(axis=2))
105
-
 
 
 
 
106
  def _calculate_confidence(self, distances):
107
  """Convert distances to confidence scores (0-100)"""
108
  min_dist = np.min(distances)
109
  max_dist = np.max(distances)
110
-
111
  # Normalize and invert (smaller distance = higher confidence)
112
  if max_dist == min_dist:
113
  return 70 # Default mid-range confidence when all distances are equal
114
-
115
  normalized_dist = (distances - min_dist) / (max_dist - min_dist)
116
  min_normalized = np.min(normalized_dist)
117
-
118
  # Invert and scale to 50-100 range (TF-IDF is never 100% confident)
119
  confidence = 100 - (min_normalized * 50)
120
  return round(confidence, 1)
121
-
122
  def _generate_explanation(self, text_vector, cluster_idx):
123
  """Generate an explanation for the classification"""
124
  # Get the most important features for this cluster
125
  centroid = self.centroids[cluster_idx]
126
-
127
  # Get indices of top features for this text
128
  text_array = text_vector.toarray()[0]
129
  top_indices = text_array.argsort()[-5:][::-1]
130
-
131
  # Get the feature names for these indices
132
  top_features = [self.feature_names[i] for i in top_indices if text_array[i] > 0]
133
-
134
  if not top_features:
135
  return "No significant features identified for this classification."
136
-
137
  explanation = f"Classification based on key terms: {', '.join(top_features)}"
138
  return explanation
139
 
140
 
141
  class LLMClassifier(BaseClassifier):
142
  """Classifier using a Large Language Model for more accurate but slower classification"""
143
-
144
  def __init__(self, client, model="gpt-3.5-turbo"):
145
  super().__init__()
146
  self.client = client
147
  self.model = model
148
-
149
- def classify(self, texts: List[str], categories: Optional[List[str]] = None) -> List[Dict[str, Any]]:
 
 
150
  """Classify texts using an LLM with parallel processing"""
151
  if not categories:
152
  # First, use LLM to generate appropriate categories
153
  categories = self._suggest_categories(texts)
154
-
155
  # Process texts in parallel
156
  with ThreadPoolExecutor(max_workers=10) as executor:
157
  # Submit all tasks with their original indices
158
  future_to_index = {
159
- executor.submit(self._classify_text, text, categories): idx
160
  for idx, text in enumerate(texts)
161
  }
162
-
163
  # Initialize results list with None values
164
  results = [None] * len(texts)
165
-
166
  # Collect results as they complete
167
  for future in as_completed(future_to_index):
168
  original_idx = future_to_index[future]
@@ -174,11 +182,11 @@ class LLMClassifier(BaseClassifier):
174
  results[original_idx] = {
175
  "category": categories[0],
176
  "confidence": 50,
177
- "explanation": f"Error during classification: {str(e)}"
178
  }
179
-
180
  return results
181
-
182
  def _suggest_categories(self, texts: List[str], sample_size: int = 20) -> List[str]:
183
  """Use LLM to suggest appropriate categories for the dataset"""
184
  # Take a sample of texts to avoid token limitations
@@ -186,54 +194,55 @@ class LLMClassifier(BaseClassifier):
186
  sample_texts = random.sample(texts, sample_size)
187
  else:
188
  sample_texts = texts
189
-
190
  prompt = CATEGORY_SUGGESTION_PROMPT.format("\n---\n".join(sample_texts))
191
-
192
  try:
193
  response = self.client.chat.completions.create(
194
  model=self.model,
195
  messages=[{"role": "user", "content": prompt}],
196
  temperature=0.2,
197
- max_tokens=100
198
  )
199
-
200
  # Parse response to get categories
201
  categories_text = response.choices[0].message.content.strip()
202
  categories = [cat.strip() for cat in categories_text.split(",")]
203
-
204
  return categories
205
  except Exception as e:
206
  # Fallback to default categories on error
207
  print(f"Error suggesting categories: {str(e)}")
208
  return self._generate_default_categories(texts)
209
-
210
  def _classify_text(self, text: str, categories: List[str]) -> Dict[str, Any]:
211
  """Use LLM to classify a single text"""
212
  prompt = TEXT_CLASSIFICATION_PROMPT.format(
213
- categories=", ".join(categories),
214
- text=text
215
  )
216
-
217
  try:
218
  response = self.client.chat.completions.create(
219
  model=self.model,
220
  messages=[{"role": "user", "content": prompt}],
221
  temperature=0,
222
- max_tokens=200
223
  )
224
-
225
  # Parse JSON response
226
  response_text = response.choices[0].message.content.strip()
227
-
228
  result = json.loads(response_text)
229
  # Ensure all required fields are present
230
  if not all(k in result for k in ["category", "confidence", "explanation"]):
231
  raise ValueError("Missing required fields in LLM response")
232
-
233
  # Validate category is in the list
234
  if result["category"] not in categories:
235
- result["category"] = categories[0] # Default to first category if invalid
236
-
 
 
237
  # Validate confidence is a number between 0 and 100
238
  try:
239
  result["confidence"] = float(result["confidence"])
@@ -241,7 +250,7 @@ class LLMClassifier(BaseClassifier):
241
  result["confidence"] = 50
242
  except:
243
  result["confidence"] = 50
244
-
245
  return result
246
  except json.JSONDecodeError:
247
  # Fall back to simple parsing if JSON fails
@@ -250,12 +259,9 @@ class LLMClassifier(BaseClassifier):
250
  if cat.lower() in response_text.lower():
251
  category = cat
252
  break
253
-
254
  return {
255
  "category": category,
256
  "confidence": 50,
257
- "explanation": f"Classification based on language model analysis. (Note: Structured response parsing failed)"
258
  }
259
-
260
-
261
-
 
9
  from typing import List, Dict, Any, Optional
10
  from prompts import CATEGORY_SUGGESTION_PROMPT, TEXT_CLASSIFICATION_PROMPT
11
 
12
+
13
  class BaseClassifier:
14
  """Base class for text classifiers"""
15
+
16
  def __init__(self):
17
  pass
18
+
19
  def classify(self, texts, categories=None):
20
  """
21
  Classify a list of texts into categories
22
+
23
  Args:
24
  texts (list): List of text strings to classify
25
  categories (list, optional): List of category names. If None, categories will be auto-detected
26
+
27
  Returns:
28
  list: List of classification results with categories, confidence scores, and explanations
29
  """
30
  raise NotImplementedError("Subclasses must implement this method")
31
+
32
  def _generate_default_categories(self, texts, num_clusters=5):
33
  """
34
  Generate default categories based on text clustering
35
+
36
  Args:
37
  texts (list): List of text strings
38
  num_clusters (int): Number of clusters to generate
39
+
40
  Returns:
41
  list: List of category names
42
  """
 
47
 
48
  class TFIDFClassifier(BaseClassifier):
49
  """Classifier using TF-IDF and clustering for fast classification"""
50
+
51
  def __init__(self):
52
  super().__init__()
53
  self.vectorizer = TfidfVectorizer(
54
+ max_features=1000, stop_words="english", ngram_range=(1, 2)
 
 
55
  )
56
  self.model = None
57
  self.feature_names = None
58
  self.categories = None
59
  self.centroids = None
60
+
61
  def classify(self, texts, categories=None):
62
  """Classify texts using TF-IDF and clustering"""
63
  # Vectorize the texts
64
  X = self.vectorizer.fit_transform(texts)
65
  self.feature_names = self.vectorizer.get_feature_names_out()
66
+
67
  # Auto-detect categories if not provided
68
  if not categories:
69
  num_clusters = min(5, len(texts)) # Don't create more clusters than texts
 
71
  else:
72
  self.categories = categories
73
  num_clusters = len(categories)
74
+
75
  # Cluster the texts
76
  self.model = KMeans(n_clusters=num_clusters, random_state=42)
77
  clusters = self.model.fit_predict(X)
78
  self.centroids = self.model.cluster_centers_
79
+
80
  # Calculate distances to centroids for confidence
81
  distances = self._calculate_distances(X)
82
+
83
  # Prepare results
84
  results = []
85
  for i, text in enumerate(texts):
86
  cluster_idx = clusters[i]
87
+
88
  # Calculate confidence (inverse of distance, normalized)
89
  confidence = self._calculate_confidence(distances[i])
90
+
91
  # Create explanation
92
  explanation = self._generate_explanation(X[i], cluster_idx)
93
+
94
+ results.append(
95
+ {
96
+ "category": self.categories[cluster_idx],
97
+ "confidence": confidence,
98
+ "explanation": explanation,
99
+ }
100
+ )
101
+
102
  return results
103
+
104
  def _calculate_distances(self, X):
105
  """Calculate distances from each point to each centroid"""
106
+ return np.sqrt(
107
+ (
108
+ (X.toarray()[:, np.newaxis, :] - self.centroids[np.newaxis, :, :]) ** 2
109
+ ).sum(axis=2)
110
+ )
111
+
112
  def _calculate_confidence(self, distances):
113
  """Convert distances to confidence scores (0-100)"""
114
  min_dist = np.min(distances)
115
  max_dist = np.max(distances)
116
+
117
  # Normalize and invert (smaller distance = higher confidence)
118
  if max_dist == min_dist:
119
  return 70 # Default mid-range confidence when all distances are equal
120
+
121
  normalized_dist = (distances - min_dist) / (max_dist - min_dist)
122
  min_normalized = np.min(normalized_dist)
123
+
124
  # Invert and scale to 50-100 range (TF-IDF is never 100% confident)
125
  confidence = 100 - (min_normalized * 50)
126
  return round(confidence, 1)
127
+
128
  def _generate_explanation(self, text_vector, cluster_idx):
129
  """Generate an explanation for the classification"""
130
  # Get the most important features for this cluster
131
  centroid = self.centroids[cluster_idx]
132
+
133
  # Get indices of top features for this text
134
  text_array = text_vector.toarray()[0]
135
  top_indices = text_array.argsort()[-5:][::-1]
136
+
137
  # Get the feature names for these indices
138
  top_features = [self.feature_names[i] for i in top_indices if text_array[i] > 0]
139
+
140
  if not top_features:
141
  return "No significant features identified for this classification."
142
+
143
  explanation = f"Classification based on key terms: {', '.join(top_features)}"
144
  return explanation
145
 
146
 
147
  class LLMClassifier(BaseClassifier):
148
  """Classifier using a Large Language Model for more accurate but slower classification"""
149
+
150
  def __init__(self, client, model="gpt-3.5-turbo"):
151
  super().__init__()
152
  self.client = client
153
  self.model = model
154
+
155
+ def classify(
156
+ self, texts: List[str], categories: Optional[List[str]] = None
157
+ ) -> List[Dict[str, Any]]:
158
  """Classify texts using an LLM with parallel processing"""
159
  if not categories:
160
  # First, use LLM to generate appropriate categories
161
  categories = self._suggest_categories(texts)
162
+
163
  # Process texts in parallel
164
  with ThreadPoolExecutor(max_workers=10) as executor:
165
  # Submit all tasks with their original indices
166
  future_to_index = {
167
+ executor.submit(self._classify_text, text, categories): idx
168
  for idx, text in enumerate(texts)
169
  }
170
+
171
  # Initialize results list with None values
172
  results = [None] * len(texts)
173
+
174
  # Collect results as they complete
175
  for future in as_completed(future_to_index):
176
  original_idx = future_to_index[future]
 
182
  results[original_idx] = {
183
  "category": categories[0],
184
  "confidence": 50,
185
+ "explanation": f"Error during classification: {str(e)}",
186
  }
187
+
188
  return results
189
+
190
  def _suggest_categories(self, texts: List[str], sample_size: int = 20) -> List[str]:
191
  """Use LLM to suggest appropriate categories for the dataset"""
192
  # Take a sample of texts to avoid token limitations
 
194
  sample_texts = random.sample(texts, sample_size)
195
  else:
196
  sample_texts = texts
197
+
198
  prompt = CATEGORY_SUGGESTION_PROMPT.format("\n---\n".join(sample_texts))
199
+
200
  try:
201
  response = self.client.chat.completions.create(
202
  model=self.model,
203
  messages=[{"role": "user", "content": prompt}],
204
  temperature=0.2,
205
+ max_tokens=100,
206
  )
207
+
208
  # Parse response to get categories
209
  categories_text = response.choices[0].message.content.strip()
210
  categories = [cat.strip() for cat in categories_text.split(",")]
211
+
212
  return categories
213
  except Exception as e:
214
  # Fallback to default categories on error
215
  print(f"Error suggesting categories: {str(e)}")
216
  return self._generate_default_categories(texts)
217
+
218
  def _classify_text(self, text: str, categories: List[str]) -> Dict[str, Any]:
219
  """Use LLM to classify a single text"""
220
  prompt = TEXT_CLASSIFICATION_PROMPT.format(
221
+ categories=", ".join(categories), text=text
 
222
  )
223
+
224
  try:
225
  response = self.client.chat.completions.create(
226
  model=self.model,
227
  messages=[{"role": "user", "content": prompt}],
228
  temperature=0,
229
+ max_tokens=200,
230
  )
231
+
232
  # Parse JSON response
233
  response_text = response.choices[0].message.content.strip()
234
+
235
  result = json.loads(response_text)
236
  # Ensure all required fields are present
237
  if not all(k in result for k in ["category", "confidence", "explanation"]):
238
  raise ValueError("Missing required fields in LLM response")
239
+
240
  # Validate category is in the list
241
  if result["category"] not in categories:
242
+ result["category"] = categories[
243
+ 0
244
+ ] # Default to first category if invalid
245
+
246
  # Validate confidence is a number between 0 and 100
247
  try:
248
  result["confidence"] = float(result["confidence"])
 
250
  result["confidence"] = 50
251
  except:
252
  result["confidence"] = 50
253
+
254
  return result
255
  except json.JSONDecodeError:
256
  # Fall back to simple parsing if JSON fails
 
259
  if cat.lower() in response_text.lower():
260
  category = cat
261
  break
262
+
263
  return {
264
  "category": category,
265
  "confidence": 50,
266
+ "explanation": f"Classification based on language model analysis. (Note: Structured response parsing failed)",
267
  }
 
 
 
prompts.py CHANGED
@@ -60,4 +60,4 @@ Example texts:
60
  {}
61
 
62
  Return your answer as a comma-separated list of new category names only.
63
- """
 
60
  {}
61
 
62
  Return your answer as a comma-separated list of new category names only.
63
+ """
utils.py CHANGED
@@ -6,61 +6,66 @@ from sklearn.decomposition import PCA
6
  from sklearn.feature_extraction.text import TfidfVectorizer
7
  import tempfile
8
 
 
9
  def load_data(file_path):
10
  """
11
  Load data from an Excel or CSV file
12
-
13
  Args:
14
  file_path (str): Path to the file
15
-
16
  Returns:
17
  pd.DataFrame: Loaded data
18
  """
19
  file_ext = os.path.splitext(file_path)[1].lower()
20
-
21
- if file_ext == '.xlsx' or file_ext == '.xls':
22
  return pd.read_excel(file_path)
23
- elif file_ext == '.csv':
24
  return pd.read_csv(file_path)
25
  else:
26
- raise ValueError(f"Unsupported file format: {file_ext}. Please upload an Excel or CSV file.")
 
 
 
27
 
28
  def export_data(df, file_name, format_type="excel"):
29
  """
30
  Export dataframe to file
31
-
32
  Args:
33
  df (pd.DataFrame): Dataframe to export
34
  file_name (str): Name of the output file
35
  format_type (str): "excel" or "csv"
36
-
37
  Returns:
38
  str: Path to the exported file
39
  """
40
  # Create export directory if it doesn't exist
41
  export_dir = "exports"
42
  os.makedirs(export_dir, exist_ok=True)
43
-
44
  # Full path for the export file
45
  export_path = os.path.join(export_dir, file_name)
46
-
47
  # Export based on format type
48
  if format_type == "excel":
49
  df.to_excel(export_path, index=False)
50
  else:
51
  df.to_csv(export_path, index=False)
52
-
53
  return export_path
54
 
 
55
  def visualize_results(df, text_column, category_column="Category"):
56
  """
57
  Create visualization of classification results
58
-
59
  Args:
60
  df (pd.DataFrame): Dataframe with classification results
61
  text_column (str): Name of the column containing text data
62
  category_column (str): Name of the column containing categories
63
-
64
  Returns:
65
  matplotlib.figure.Figure: Visualization figure
66
  """
@@ -68,52 +73,58 @@ def visualize_results(df, text_column, category_column="Category"):
68
  if category_column not in df.columns:
69
  # Create a simple figure with a message
70
  fig, ax = plt.subplots(figsize=(10, 6))
71
- ax.text(0.5, 0.5, "No categories to display",
72
- ha='center', va='center', fontsize=12)
73
- ax.set_title('No Classification Results Available')
 
74
  plt.tight_layout()
75
  return fig
76
-
77
  # Get categories and their counts
78
  category_counts = df[category_column].value_counts()
79
-
80
  # Create a new figure
81
  fig, ax = plt.subplots(figsize=(10, 6))
82
-
83
  # Create the histogram
84
  bars = ax.bar(category_counts.index, category_counts.values)
85
-
86
  # Add value labels on top of each bar
87
  for bar in bars:
88
  height = bar.get_height()
89
- ax.text(bar.get_x() + bar.get_width()/2., height,
90
- f'{int(height)}',
91
- ha='center', va='bottom')
92
-
 
 
 
 
93
  # Customize the plot
94
- ax.set_xlabel('Categories')
95
- ax.set_ylabel('Number of Texts')
96
- ax.set_title('Distribution of Classified Texts')
97
-
98
  # Rotate x-axis labels if they're too long
99
- plt.xticks(rotation=45, ha='right')
100
-
101
  # Add grid
102
- ax.grid(True, linestyle='--', alpha=0.7)
103
-
104
  plt.tight_layout()
105
-
106
  return fig
107
 
 
108
  def validate_results(df, text_columns, client):
109
  """
110
  Use LLM to validate the classification results
111
-
112
  Args:
113
  df (pd.DataFrame): Dataframe with classification results
114
  text_columns (list): List of column names containing text data
115
  client: LiteLLM client
116
-
117
  Returns:
118
  str: Validation report
119
  """
@@ -121,7 +132,7 @@ def validate_results(df, text_columns, client):
121
  # Sample a few rows for validation
122
  sample_size = min(5, len(df))
123
  sample_df = df.sample(n=sample_size, random_state=42)
124
-
125
  # Build validation prompt
126
  validation_prompts = []
127
  for _, row in sample_df.iterrows():
@@ -129,11 +140,11 @@ def validate_results(df, text_columns, client):
129
  text = " ".join(str(row[col]) for col in text_columns)
130
  assigned_category = row["Category"]
131
  confidence = row["Confidence"]
132
-
133
  validation_prompts.append(
134
  f"Text: {text}\nAssigned Category: {assigned_category}\nConfidence: {confidence}\n"
135
  )
136
-
137
  prompt = """
138
  As a validation expert, review the following text classifications and provide feedback.
139
  For each text, assess whether the assigned category seems appropriate:
@@ -146,19 +157,21 @@ def validate_results(df, text_columns, client):
146
  3. Suggestions for improvement
147
 
148
  Keep your response under 300 words.
149
- """.format("\n---\n".join(validation_prompts))
150
-
 
 
151
  # Call LLM API
152
  response = client.chat.completions.create(
153
  model="gpt-3.5-turbo",
154
  messages=[{"role": "user", "content": prompt}],
155
  temperature=0.3,
156
- max_tokens=400
157
  )
158
-
159
  validation_report = response.choices[0].message.content.strip()
160
  return validation_report
161
-
162
  except Exception as e:
163
  return f"Validation failed: {str(e)}"
164
 
@@ -166,7 +179,7 @@ def validate_results(df, text_columns, client):
166
  def create_example_file():
167
  """
168
  Create an example CSV file for testing
169
-
170
  Returns:
171
  str: Path to the created file
172
  """
@@ -182,17 +195,17 @@ def create_example_file():
182
  "It's okay, nothing special but gets the job done.",
183
  "I'm extremely disappointed with the quality of this product.",
184
  "This is the best purchase I've made all year!",
185
- "It's reasonably priced and works as expected."
186
  ]
187
  }
188
-
189
  # Create dataframe
190
  df = pd.DataFrame(data)
191
-
192
  # Save to a CSV file
193
  example_dir = "examples"
194
  os.makedirs(example_dir, exist_ok=True)
195
  file_path = os.path.join(example_dir, "sample_reviews.csv")
196
  df.to_csv(file_path, index=False)
197
-
198
  return file_path
 
6
  from sklearn.feature_extraction.text import TfidfVectorizer
7
  import tempfile
8
 
9
+
10
  def load_data(file_path):
11
  """
12
  Load data from an Excel or CSV file
13
+
14
  Args:
15
  file_path (str): Path to the file
16
+
17
  Returns:
18
  pd.DataFrame: Loaded data
19
  """
20
  file_ext = os.path.splitext(file_path)[1].lower()
21
+
22
+ if file_ext == ".xlsx" or file_ext == ".xls":
23
  return pd.read_excel(file_path)
24
+ elif file_ext == ".csv":
25
  return pd.read_csv(file_path)
26
  else:
27
+ raise ValueError(
28
+ f"Unsupported file format: {file_ext}. Please upload an Excel or CSV file."
29
+ )
30
+
31
 
32
  def export_data(df, file_name, format_type="excel"):
33
  """
34
  Export dataframe to file
35
+
36
  Args:
37
  df (pd.DataFrame): Dataframe to export
38
  file_name (str): Name of the output file
39
  format_type (str): "excel" or "csv"
40
+
41
  Returns:
42
  str: Path to the exported file
43
  """
44
  # Create export directory if it doesn't exist
45
  export_dir = "exports"
46
  os.makedirs(export_dir, exist_ok=True)
47
+
48
  # Full path for the export file
49
  export_path = os.path.join(export_dir, file_name)
50
+
51
  # Export based on format type
52
  if format_type == "excel":
53
  df.to_excel(export_path, index=False)
54
  else:
55
  df.to_csv(export_path, index=False)
56
+
57
  return export_path
58
 
59
+
60
  def visualize_results(df, text_column, category_column="Category"):
61
  """
62
  Create visualization of classification results
63
+
64
  Args:
65
  df (pd.DataFrame): Dataframe with classification results
66
  text_column (str): Name of the column containing text data
67
  category_column (str): Name of the column containing categories
68
+
69
  Returns:
70
  matplotlib.figure.Figure: Visualization figure
71
  """
 
73
  if category_column not in df.columns:
74
  # Create a simple figure with a message
75
  fig, ax = plt.subplots(figsize=(10, 6))
76
+ ax.text(
77
+ 0.5, 0.5, "No categories to display", ha="center", va="center", fontsize=12
78
+ )
79
+ ax.set_title("No Classification Results Available")
80
  plt.tight_layout()
81
  return fig
82
+
83
  # Get categories and their counts
84
  category_counts = df[category_column].value_counts()
85
+
86
  # Create a new figure
87
  fig, ax = plt.subplots(figsize=(10, 6))
88
+
89
  # Create the histogram
90
  bars = ax.bar(category_counts.index, category_counts.values)
91
+
92
  # Add value labels on top of each bar
93
  for bar in bars:
94
  height = bar.get_height()
95
+ ax.text(
96
+ bar.get_x() + bar.get_width() / 2.0,
97
+ height,
98
+ f"{int(height)}",
99
+ ha="center",
100
+ va="bottom",
101
+ )
102
+
103
  # Customize the plot
104
+ ax.set_xlabel("Categories")
105
+ ax.set_ylabel("Number of Texts")
106
+ ax.set_title("Distribution of Classified Texts")
107
+
108
  # Rotate x-axis labels if they're too long
109
+ plt.xticks(rotation=45, ha="right")
110
+
111
  # Add grid
112
+ ax.grid(True, linestyle="--", alpha=0.7)
113
+
114
  plt.tight_layout()
115
+
116
  return fig
117
 
118
+
119
  def validate_results(df, text_columns, client):
120
  """
121
  Use LLM to validate the classification results
122
+
123
  Args:
124
  df (pd.DataFrame): Dataframe with classification results
125
  text_columns (list): List of column names containing text data
126
  client: LiteLLM client
127
+
128
  Returns:
129
  str: Validation report
130
  """
 
132
  # Sample a few rows for validation
133
  sample_size = min(5, len(df))
134
  sample_df = df.sample(n=sample_size, random_state=42)
135
+
136
  # Build validation prompt
137
  validation_prompts = []
138
  for _, row in sample_df.iterrows():
 
140
  text = " ".join(str(row[col]) for col in text_columns)
141
  assigned_category = row["Category"]
142
  confidence = row["Confidence"]
143
+
144
  validation_prompts.append(
145
  f"Text: {text}\nAssigned Category: {assigned_category}\nConfidence: {confidence}\n"
146
  )
147
+
148
  prompt = """
149
  As a validation expert, review the following text classifications and provide feedback.
150
  For each text, assess whether the assigned category seems appropriate:
 
157
  3. Suggestions for improvement
158
 
159
  Keep your response under 300 words.
160
+ """.format(
161
+ "\n---\n".join(validation_prompts)
162
+ )
163
+
164
  # Call LLM API
165
  response = client.chat.completions.create(
166
  model="gpt-3.5-turbo",
167
  messages=[{"role": "user", "content": prompt}],
168
  temperature=0.3,
169
+ max_tokens=400,
170
  )
171
+
172
  validation_report = response.choices[0].message.content.strip()
173
  return validation_report
174
+
175
  except Exception as e:
176
  return f"Validation failed: {str(e)}"
177
 
 
179
  def create_example_file():
180
  """
181
  Create an example CSV file for testing
182
+
183
  Returns:
184
  str: Path to the created file
185
  """
 
195
  "It's okay, nothing special but gets the job done.",
196
  "I'm extremely disappointed with the quality of this product.",
197
  "This is the best purchase I've made all year!",
198
+ "It's reasonably priced and works as expected.",
199
  ]
200
  }
201
+
202
  # Create dataframe
203
  df = pd.DataFrame(data)
204
+
205
  # Save to a CSV file
206
  example_dir = "examples"
207
  os.makedirs(example_dir, exist_ok=True)
208
  file_path = os.path.join(example_dir, "sample_reviews.csv")
209
  df.to_csv(file_path, index=False)
210
+
211
  return file_path