# classifieur/utils.py
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.feature_extraction.text import TfidfVectorizer
import tempfile
from prompts import VALIDATION_PROMPT
from typing import List, Optional, Any, Union, Tuple
from pathlib import Path
from matplotlib.figure import Figure


def load_data(file_path: Union[str, Path]) -> pd.DataFrame:
    """
    Load data from an Excel or CSV file.

    Args:
        file_path (Union[str, Path]): Path to the file

    Returns:
        pd.DataFrame: Loaded data
    """
    file_ext: str = os.path.splitext(file_path)[1].lower()

    if file_ext in (".xlsx", ".xls"):
        return pd.read_excel(file_path)
    elif file_ext == ".csv":
        return pd.read_csv(file_path)
    else:
        raise ValueError(
            f"Unsupported file format: {file_ext}. Please upload an Excel or CSV file."
        )
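
# Example usage (illustrative only; the file path below is hypothetical):
#   df = load_data("data/responses.xlsx")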


def analyze_text_columns(df: pd.DataFrame) -> List[str]:
    """
    Analyze columns to suggest text columns based on content analysis.

    Args:
        df (pd.DataFrame): Input dataframe

    Returns:
        List[str]: List of suggested text columns
    """
    suggested_text_columns: List[str] = []

    for col in df.columns:
        if df[col].dtype == "object":  # String type
            # Check if column contains mostly text (not just numbers or dates)
            sample = df[col].head(100).dropna()
            if len(sample) > 0:
                # Check if most values contain spaces (indicating text)
                text_ratio = sum(" " in str(val) for val in sample) / len(sample)
                if text_ratio > 0.3:  # If more than 30% of values contain spaces
                    suggested_text_columns.append(col)

    # If no columns were suggested, fall back to all object columns
    if not suggested_text_columns:
        suggested_text_columns = [
            col for col in df.columns if df[col].dtype == "object"
        ]

    return suggested_text_columns
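
# Example (illustrative): on a dataframe with a numeric "id" column and a
# free-text "comment" column, this heuristic is expected to suggest ["comment"],
# since ids rarely contain spaces while sentences usually do.
#   text_cols = analyze_text_columns(df)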


def get_sample_texts(
    df: pd.DataFrame, text_columns: List[str], sample_size: int = 5
) -> List[str]:
    """
    Get sample texts from specified columns.

    Args:
        df (pd.DataFrame): Input dataframe
        text_columns (List[str]): List of text column names
        sample_size (int): Number of samples to take from each column

    Returns:
        List[str]: List of sample texts
    """
    sample_texts: List[str] = []

    for col in text_columns:
        sample_texts.extend(df[col].head(sample_size).tolist())

    return sample_texts
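
# Example (illustrative): preview the first 3 values of each selected text
# column before sending anything to the classifier.
#   samples = get_sample_texts(df, text_cols, sample_size=3)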


def export_data(df: pd.DataFrame, file_name: str, format_type: str = "excel") -> str:
    """
    Export a dataframe to a file.

    Args:
        df (pd.DataFrame): Dataframe to export
        file_name (str): Name of the output file
        format_type (str): "excel" or "csv"

    Returns:
        str: Path to the exported file
    """
    # Create the export directory if it doesn't exist
    export_dir: str = "exports"
    os.makedirs(export_dir, exist_ok=True)

    # Full path for the export file
    export_path: str = os.path.join(export_dir, file_name)

    # Export based on format type
    if format_type == "excel":
        df.to_excel(export_path, index=False)
    else:
        df.to_csv(export_path, index=False)

    return export_path
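
# Example (illustrative; file names are hypothetical). Note that files are
# written into a local "exports/" directory relative to the working directory.
#   path = export_data(df, "results.xlsx", format_type="excel")
#   path = export_data(df, "results.csv", format_type="csv")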


def visualize_results(
    df: pd.DataFrame, text_column: str, category_column: str = "Category"
) -> Figure:
    """
    Create a visualization of classification results.

    Args:
        df (pd.DataFrame): Dataframe with classification results
        text_column (str): Name of the column containing text data
        category_column (str): Name of the column containing categories

    Returns:
        matplotlib.figure.Figure: Visualization figure
    """
    # If the category column is missing, return a placeholder figure with a message
    if category_column not in df.columns:
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.text(
            0.5, 0.5, "No categories to display", ha="center", va="center", fontsize=12
        )
        ax.set_title("No Classification Results Available")
        plt.tight_layout()
        return fig

    # Get categories and their counts
    category_counts: pd.Series = df[category_column].value_counts()

    # Create a new figure with a bar chart of category counts
    fig, ax = plt.subplots(figsize=(10, 6))
    bars = ax.bar(category_counts.index, category_counts.values)

    # Add value labels on top of each bar
    for bar in bars:
        height: float = bar.get_height()
        ax.text(
            bar.get_x() + bar.get_width() / 2.0,
            height,
            f"{int(height)}",
            ha="center",
            va="bottom",
        )

    # Customize the plot
    ax.set_xlabel("Categories")
    ax.set_ylabel("Number of Texts")
    ax.set_title("Distribution of Classified Texts")

    # Rotate x-axis labels in case they are long
    plt.xticks(rotation=45, ha="right")

    # Add a grid
    ax.grid(True, linestyle="--", alpha=0.7)

    plt.tight_layout()
    return fig
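
# Example (illustrative): render the category distribution and save it to disk.
# The output path is hypothetical.
#   fig = visualize_results(df, text_column="comment")
#   fig.savefig("exports/category_distribution.png", dpi=150)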


def validate_results(df: pd.DataFrame, text_columns: List[str], client: Any) -> str:
    """
    Use an LLM to validate the classification results.

    Args:
        df (pd.DataFrame): Dataframe with classification results
        text_columns (List[str]): List of column names containing text data
        client: LLM client exposing an OpenAI-compatible chat.completions API (e.g. via LiteLLM)

    Returns:
        str: Validation report
    """
    try:
        # Sample a few rows for validation
        sample_size: int = min(5, len(df))
        sample_df: pd.DataFrame = df.sample(n=sample_size, random_state=42)

        # Build validation prompts
        validation_prompts: List[str] = []
        for _, row in sample_df.iterrows():
            # Combine text from all selected columns
            text: str = " ".join(str(row[col]) for col in text_columns)
            assigned_category: str = row["Category"]
            confidence: float = row["Confidence"]
            validation_prompts.append(
                f"Text: {text}\nAssigned Category: {assigned_category}\nConfidence: {confidence}\n"
            )

        # Use the prompt template from prompts.py
        prompt: str = VALIDATION_PROMPT.format("\n---\n".join(validation_prompts))

        # Call the LLM API
        response: Any = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3,
            max_tokens=400,
        )

        validation_report: str = response.choices[0].message.content.strip()
        return validation_report
    except Exception as e:
        return f"Validation failed: {str(e)}"