Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import asyncio | |
import json | |
from sklearn.cluster import KMeans | |
from sklearn.decomposition import PCA | |
import matplotlib.pyplot as plt | |
import logging | |
from dotenv import load_dotenv | |
from process import update_api_key, process_file_async, export_results, improve_classification | |
from client import get_client, initialize_client | |
from utils import load_data, visualize_results, analyze_text_columns, get_sample_texts | |
from classifiers.llm import LLMClassifier | |
# Load environment variables from a .env file (python-dotenv) so that
# os.environ.get("OPENAI_API_KEY") below can pick the key up from .env.
# NOTE(review): this call runs *after* the local `process`, `client`, `utils`
# and `classifiers.llm` imports above; if any of those modules read environment
# variables at import time they will not see values from .env — consider
# moving this call above those imports. TODO confirm.
load_dotenv()
# Import local modules | |
from prompts import ( | |
CATEGORY_SUGGESTION_PROMPT, | |
ADDITIONAL_CATEGORY_PROMPT, | |
VALIDATION_ANALYSIS_PROMPT, | |
CATEGORY_IMPROVEMENT_PROMPT, | |
) | |
# Configure root logging for the whole app (timestamp, logger name, level).
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

# Initialize API key from environment variable ("" when unset).
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")

# Initialize the shared OpenAI client at startup when a key is available.
# `initialize_client` reports failure through its (success, message) return
# value rather than by raising.
if OPENAI_API_KEY:
    success, message = initialize_client(OPENAI_API_KEY)
    if success:
        logging.info("OpenAI client initialized successfully")
    else:
        # Lazy %-style args: the message is only formatted if the record is emitted.
        logging.error("Failed to initialize OpenAI client: %s", message)
else:
    # Make the "no key" startup state visible in the logs instead of silent;
    # a key can still be entered later in the Setup tab.
    logging.warning(
        "OPENAI_API_KEY is not set; no OpenAI client was initialized at startup"
    )
# Create Gradio interface | |
with gr.Blocks(title="Text Classification System") as demo:
    gr.Markdown("# Text Classification System")
    gr.Markdown("Upload your data file (Excel/CSV) and classify text using AI")

    # Categories offered when no LLM client is available or when the LLM
    # category suggestion fails or returns nothing.
    _FALLBACK_CATEGORIES = ("Positive", "Negative", "Neutral", "Mixed", "Other")

    def _api_status_markdown():
        """Render the API status line from the client that exists right now.

        Calls ``get_client()`` at call time so the text reflects a key entered
        through the Setup tab, not only the key present at startup.
        """
        status = (
            "API Key is set" if get_client() else "No API Key found. Please set one."
        )
        return f"**Current API Status**: {status}"

    def _file_path(file):
        """Return the filesystem path for a ``gr.File`` value.

        Depending on the Gradio version the value is a tempfile-like object
        exposing ``.name`` or a plain path string; both are accepted.
        """
        return getattr(file, "name", file)

    with gr.Tab("Setup"):
        api_key_input = gr.Textbox(
            label="OpenAI API Key",
            placeholder="Enter your API key here",
            type="password",
            value=OPENAI_API_KEY,
        )
        api_key_button = gr.Button("Update API Key")
        api_key_message = gr.Textbox(label="Status", interactive=False)
        # Display current API status (snapshot at build time; refreshed below
        # every time the key is updated).
        client = get_client()
        api_status = "API Key is set" if client else "No API Key found. Please set one."
        api_status_display = gr.Markdown(f"**Current API Status**: {api_status}")
        # NOTE(review): the refresh assumes update_api_key re-initializes the
        # client returned by get_client() — confirm in process.py.
        api_key_button.click(
            update_api_key, inputs=[api_key_input], outputs=[api_key_message]
        ).then(_api_status_markdown, inputs=[], outputs=[api_status_display])

    with gr.Tab("Classify Data"):
        with gr.Column():
            file_input = gr.File(label="Upload Excel/CSV File")
            # Variable to store available columns
            available_columns = gr.State([])
            # Button to load file and suggest categories
            load_categories_button = gr.Button("Load File")
            # Display original dataframe
            original_df = gr.Dataframe(
                label="Original Data", interactive=False, visible=False
            )
            with gr.Row():
                with gr.Column():
                    suggested_categories = gr.CheckboxGroup(
                        label="Suggested Categories",
                        choices=[],
                        value=[],
                        interactive=True,
                        visible=False,
                    )
                    new_category = gr.Textbox(
                        label="Add New Category",
                        placeholder="Enter a new category name",
                        visible=False,
                    )
                    with gr.Row():
                        add_category_button = gr.Button("Add Category", visible=False)
                        suggest_category_button = gr.Button(
                            "Suggest Category", visible=False
                        )
                    # Comma-joined selected categories (hidden); this is what
                    # the classification callbacks receive.
                    categories = gr.Textbox(visible=False)
                with gr.Column():
                    text_column = gr.CheckboxGroup(
                        label="Select Text Columns",
                        choices=[],
                        interactive=True,
                        visible=False,
                    )
                    classifier_type = gr.Dropdown(
                        choices=[
                            ("TF-IDF (Rapide, <1000 lignes)", "tfidf"),
                            ("LLM GPT-3.5 (Fiable, <1000 lignes)", "gpt35"),
                            ("LLM GPT-4 (Très fiable, <500 lignes)", "gpt4"),
                            ("TF-IDF + LLM (Hybride, >1000 lignes)", "hybrid"),
                        ],
                        label="Modèle de classification",
                        value="gpt35",
                        visible=False,
                    )
                    show_explanations = gr.Checkbox(
                        label="Show Explanations", value=True, visible=False
                    )
            process_button = gr.Button("Process and Classify", visible=False)
            results_df = gr.Dataframe(interactive=True, visible=False)
            # Create containers for visualization and validation report
            with gr.Row(visible=False) as results_row:
                with gr.Column():
                    visualization = gr.Plot(label="Classification Distribution")
                    with gr.Row():
                        csv_download = gr.File(label="Download CSV", visible=False)
                        excel_download = gr.File(label="Download Excel", visible=False)
                with gr.Column():
                    validation_output = gr.Textbox(
                        label="Validation Report", interactive=True, lines=15
                    )
                    improve_button = gr.Button(
                        "Improve Classification with Report", visible=False
                    )

    def _hidden_load_outputs():
        """Outputs for ``load_file_and_suggest_categories`` when no file is usable.

        Tuple order matches the ``outputs`` list wired to
        ``load_categories_button`` below.
        """
        return (
            [],  # available_columns
            gr.CheckboxGroup(choices=[], value=[], visible=False),  # suggested_categories
            gr.Textbox(visible=False),  # new_category
            gr.Button(visible=False),  # add_category_button
            gr.Button(visible=False),  # suggest_category_button
            gr.CheckboxGroup(choices=[], value=[], visible=False),  # text_column
            gr.Dropdown(visible=False),  # classifier_type
            gr.Checkbox(visible=False),  # show_explanations
            gr.Button(visible=False),  # process_button
            gr.Dataframe(visible=False),  # original_df
        )

    async def load_file_and_suggest_categories(file):
        """Load the uploaded file, pre-select text columns and suggest categories.

        Args:
            file: Value of ``file_input`` (``None`` when nothing is uploaded).

        Returns:
            A tuple in the order of the ``outputs`` wired to
            ``load_categories_button``. When loading fails, every control is
            hidden again. When category suggestion fails (or there is no
            client), ``_FALLBACK_CATEGORIES`` is offered instead.
        """
        if not file:
            return _hidden_load_outputs()
        path = _file_path(file)
        try:
            df = load_data(path)
            columns = list(df.columns)
            # Analyze columns to suggest text columns
            suggested_text_columns = analyze_text_columns(df)
        except Exception:
            logging.exception("Failed to load file %s", path)
            return _hidden_load_outputs()

        suggested_cats = None
        # Resolve the client now, not at UI build time, so a key entered in the
        # Setup tab after startup is honored.
        llm_client = get_client()
        if llm_client:
            try:
                # Get sample texts for category suggestion
                sample_texts = get_sample_texts(df, suggested_text_columns)
                classifier = LLMClassifier(client=llm_client)
                suggested_cats = await classifier.suggest_categories_from_texts(
                    sample_texts
                )
            except Exception:
                # A failed suggestion should not hide an otherwise loaded file.
                logging.exception(
                    "Category suggestion failed; falling back to default categories"
                )
        if not suggested_cats:
            suggested_cats = list(_FALLBACK_CATEGORIES)

        return (
            columns,
            gr.CheckboxGroup(
                choices=suggested_cats, value=suggested_cats, visible=True
            ),
            gr.Textbox(visible=True),
            gr.Button(visible=True),
            gr.Button(visible=True),
            gr.CheckboxGroup(
                choices=columns, value=suggested_text_columns, visible=True
            ),
            gr.Dropdown(visible=True),
            gr.Checkbox(visible=True),
            gr.Button(visible=True),
            gr.Dataframe(value=df, visible=True),
        )

    def add_new_category(current_categories, new_category):
        """Append a user-typed category to the category checkbox group.

        Args:
            current_categories: The CheckboxGroup's current value, i.e. the
                *selected* categories (may be ``None``).
            new_category: Raw text from the "Add New Category" textbox.

        Returns:
            The unchanged selection when the text is blank or already selected;
            otherwise a CheckboxGroup whose choices and value are the selection
            plus the new (stripped) category.
        """
        selected = list(current_categories or [])
        name = (new_category or "").strip()
        if not name or name in selected:
            return selected
        # NOTE(review): choices are rebuilt from the *selected* values only, so
        # deselected suggestions disappear; keeping them needs extra state.
        updated = selected + [name]
        return gr.CheckboxGroup(choices=updated, value=updated)

    def update_categories_textbox(selected_categories):
        """Mirror the selected categories into the hidden comma-joined textbox."""
        return ", ".join(selected_categories or [])

    def show_results(df, validation_report):
        """Reveal the results area and export downloadable CSV/Excel files.

        ``validation_report`` is accepted so the event's inputs line up, but it
        is not used. When ``df`` is ``None`` the results area stays hidden.
        """
        if df is None:
            return (
                gr.Row(visible=False),
                gr.File(visible=False),
                gr.File(visible=False),
                gr.Dataframe(visible=False),
            )
        # Export to both formats
        csv_path = export_results(df, "csv")
        excel_path = export_results(df, "excel")
        return (
            gr.Row(visible=True),
            gr.File(value=csv_path, visible=True),
            gr.File(value=excel_path, visible=True),
            gr.Dataframe(value=df, visible=True),
        )

    async def suggest_new_category(file, current_categories, text_columns):
        """Ask the LLM for an updated category list based on the selected columns.

        Returns the current selection unchanged when there is no file, no text
        column, no client, an empty suggestion, or any error (which is logged).
        """
        current = list(current_categories or [])
        unchanged = gr.CheckboxGroup(choices=current, value=current)
        if not file or not text_columns:
            return unchanged
        # Resolve the client at call time (see load_file_and_suggest_categories).
        llm_client = get_client()
        if not llm_client:
            return unchanged
        try:
            df = load_data(_file_path(file))
            sample_texts = get_sample_texts(df, text_columns)
            classifier = LLMClassifier(client=llm_client)
            new_categories = await classifier.suggest_categories_from_texts(
                sample_texts, current
            )
        except Exception:
            logging.exception("Failed to suggest additional categories")
            return unchanged
        if not new_categories:
            return unchanged
        return gr.CheckboxGroup(choices=new_categories, value=new_categories)

    def handle_export(df, format_type):
        """Export ``df`` in ``format_type`` and return a visible download component."""
        if df is None:
            return gr.File(visible=False)
        file_path = export_results(df, format_type)
        return gr.File(value=file_path, visible=True)

    # Connect functions
    load_categories_button.click(
        load_file_and_suggest_categories,
        inputs=[file_input],
        # Order must match the tuples returned by load_file_and_suggest_categories.
        outputs=[
            available_columns,
            suggested_categories,
            new_category,
            add_category_button,
            suggest_category_button,
            text_column,
            classifier_type,
            show_explanations,
            process_button,
            original_df,
        ],
    )
    add_category_button.click(
        add_new_category,
        inputs=[suggested_categories, new_category],
        outputs=[suggested_categories],
    )
    suggested_categories.change(
        update_categories_textbox,
        inputs=[suggested_categories],
        outputs=[categories],
    )
    suggest_category_button.click(
        suggest_new_category,
        inputs=[file_input, suggested_categories, text_column],
        outputs=[suggested_categories],
    )
    process_button.click(
        lambda: gr.Dataframe(visible=True), inputs=[], outputs=[results_df]
    ).then(
        process_file_async,
        inputs=[
            file_input,
            text_column,
            categories,
            classifier_type,
            show_explanations,
        ],
        outputs=[results_df, validation_output],
    ).then(
        show_results,
        inputs=[results_df, validation_output],
        outputs=[results_row, csv_download, excel_download, results_df],
    ).then(
        visualize_results, inputs=[results_df, text_column], outputs=[visualization]
    ).then(
        # inputs=[] means Gradio calls this with no arguments, so the lambda
        # must take none (a one-argument lambda raised TypeError here).
        lambda: gr.Button(visible=True),
        inputs=[],
        outputs=[improve_button],
    )
    improve_button.click(
        improve_classification,
        inputs=[
            results_df,
            validation_output,
            text_column,
            categories,
            classifier_type,
            show_explanations,
            file_input,
        ],
        outputs=[
            results_df,
            validation_output,
            improve_button,
            suggested_categories,
        ],
    ).then(
        show_results,
        inputs=[results_df, validation_output],
        outputs=[results_row, csv_download, excel_download, results_df],
    ).then(
        visualize_results, inputs=[results_df, text_column], outputs=[visualization]
    )
def create_example_data():
    """Generate the sample dataset shipped with the demo.

    Returns:
        A human-readable message containing the path reported by
        ``utils.create_example_file``.
    """
    # Imported lazily: in this file the helper is only needed when the example
    # data is bootstrapped from the __main__ guard.
    from utils import create_example_file as make_example_file

    created_path = make_example_file()
    return f"Example file created at: {created_path}"
if __name__ == "__main__":
    # First-run bootstrap: generate sample data only when there is no
    # "examples" path yet.
    examples_missing = not os.path.exists("examples")
    if examples_missing:
        create_example_data()
    # Start the Gradio app defined above.
    demo.launch()