jzou19950715's picture
Update app.py
57afec9 verified
raw
history blame
11.4 kB
import os
import requests
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder
import tempfile
import shutil
# Scratch directory that holds every PNG plot generated while the app runs.
# It is created once, at import time, and removed by cleanup_temp_files().
TEMP_DIR: str = tempfile.mkdtemp()
def cleanup_temp_files():
    """Delete the plot scratch directory ``TEMP_DIR`` and all of its contents.

    Best-effort: any failure is printed instead of raised, so calling this
    during shutdown never throws.
    """
    try:
        shutil.rmtree(TEMP_DIR)
    except Exception as exc:  # deliberate best-effort cleanup
        report = "Error cleaning up temporary files: " + str(exc)
        print(report)
def get_temp_path(filename):
    """Return the path at which *filename* should live inside ``TEMP_DIR``.

    The path is built with ``os.path.join``; no file or directory is created.
    """
    target = os.path.join(TEMP_DIR, filename)
    return target
def call_gpt4o_mini(api_key, context, user_prompt, model="gpt-4o-mini", timeout=10):
    """Ask a chat-completions endpoint a question about the loaded data.

    The endpoint URL comes from the ``GPT4O_MINI_API_URL`` environment variable
    and defaults to OpenAI's public chat-completions URL.

    Args:
        api_key: Bearer token; must start with ``"sk-"``. Surrounding whitespace
            (common when pasting a key) is ignored; non-string values are
            treated as "no key".
        context: Data summary / chat-history text sent ahead of the question.
        user_prompt: The user's question.
        model: Model name sent in the request. Defaults to ``"gpt-4o-mini"``,
            the model this function is named after (the request previously
            sent ``"gpt-4"`` despite the name).
        timeout: Seconds to wait for the HTTP response.

    Returns:
        The assistant's reply text, or a human-readable error message. Every
        failure is returned as a string (never raised) so the UI can show it.
    """
    # Stray whitespace/newlines in a pasted key would corrupt the Authorization
    # header, and a non-string value would crash .startswith() below.
    api_key = api_key.strip() if isinstance(api_key, str) else ""
    if not api_key.startswith("sk-"):
        return "Please provide a valid API key (should start with 'sk-')"
    try:
        url = os.getenv("GPT4O_MINI_API_URL", "https://api.openai.com/v1/chat/completions")
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        messages = [
            {"role": "system", "content": "You are a data analysis assistant. Analyze the provided data and context."},
            {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {user_prompt}"},
        ]
        payload = {
            "model": model,
            "messages": messages,
            "max_tokens": 500,
            "temperature": 0.7,
        }
        response = requests.post(url, json=payload, headers=headers, timeout=timeout)
        # Give the most common failure a friendlier message than raise_for_status().
        if response.status_code == 401:
            return "Invalid API key. Please check your credentials."
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
    except requests.exceptions.RequestException as e:
        return f"API Error: {str(e)}"
    except (KeyError, IndexError, TypeError) as e:
        # The body parsed but did not have the expected choices/message shape.
        return f"Unexpected API response format: {e}"
    except Exception as e:
        return f"Error: {str(e)}"
def _render_plot(filename, figsize, draw):
    """Create a figure, let ``draw(ax)`` populate it, save it and return the path.

    The PNG is written at 300 dpi to ``get_temp_path(filename)``. The figure is
    closed even when drawing or saving raises, so failed requests do not leave
    open figures behind.
    """
    fig, ax = plt.subplots(figsize=figsize)
    try:
        draw(ax)
        fig.tight_layout()
        path = get_temp_path(filename)
        fig.savefig(path, dpi=300, bbox_inches="tight")
    finally:
        plt.close(fig)
    return path


def extended_analysis(df):
    """Produce exploratory plots for *df*, plus a quick career classifier.

    Plots (each saved as a PNG through ``get_temp_path``):
      1. correlation heatmap of the numeric columns (needs >= 2 of them);
      2. bar chart of ``Career`` value counts (if a ``Career`` column exists);
      3. box plots of numeric columns whose name contains "score"/"aptitude";
      4. feature importance of a logistic-regression model predicting
         ``Career`` from the other numeric columns.

    *df* is never modified.

    Returns:
        ``(paths, message)``. *paths* lists the PNGs written so far (partial
        if an error occurs). *message* is one of: the held-out model accuracy,
        "Insufficient unique careers for classification",
        "Analysis completed successfully", or "Error during analysis: ...".
        Errors are reported in *message*, not raised.
    """
    output_paths = []
    numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
    try:
        # 1. Correlation heatmap; np.triu masks the upper triangle (incl. the
        #    diagonal), which only mirrors the lower one.
        if len(numeric_cols) > 1:
            corr = df[numeric_cols].corr()
            mask = np.triu(np.ones_like(corr, dtype=bool))

            def draw_heatmap(ax):
                sns.heatmap(corr, mask=mask, annot=True, cmap="coolwarm", fmt=".2f",
                            square=True, linewidths=.5, ax=ax)
                ax.set_title("Correlation Heatmap of Numeric Features")

            output_paths.append(_render_plot("heatmap.png", (12, 8), draw_heatmap))

        # 2. Career distribution
        if "Career" in df.columns:
            career_counts = df["Career"].value_counts()

            def draw_careers(ax):
                sns.barplot(x=career_counts.index, y=career_counts.values, ax=ax)
                ax.set_title("Distribution of Careers")
                plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
                ax.set_xlabel("Career")
                ax.set_ylabel("Count")

            output_paths.append(_render_plot("career_distribution.png", (12, 6), draw_careers))

        # 3. Box plots for scores (only numeric columns can be box-plotted;
        #    str() guards against non-string column labels).
        score_columns = [
            col for col in numeric_cols
            if "score" in str(col).lower() or "aptitude" in str(col).lower()
        ]
        if score_columns:
            def draw_scores(ax):
                df[score_columns].boxplot(ax=ax)
                ax.set_title("Distribution of Scores and Aptitudes")
                plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
                ax.set_ylabel("Score")
                ax.grid(True, alpha=0.3)

            output_paths.append(_render_plot("scores_distribution.png", (15, 8), draw_scores))

        # 4. Machine learning analysis
        if "Career" in df.columns:
            labeled = df[df["Career"].notna()]
            # Never use the target as a predictor (it is numeric when careers are
            # coded as numbers), and drop columns with no values at all, which
            # mean-imputation cannot fill.
            feature_cols = [
                col for col in numeric_cols
                if col != "Career" and labeled[col].notna().any()
            ]
            if feature_cols:
                X = labeled[feature_cols]
                X = X.fillna(X.mean())
                # Encode into a local array rather than adding a column to df:
                # df is the caller's cached DataFrame, and a numeric
                # "Career_encoded" column would become a feature (target leakage)
                # on the next call. astype(str) avoids TypeError on mixed types.
                y = LabelEncoder().fit_transform(labeled["Career"].astype(str))
                if len(np.unique(y)) > 1:
                    X_train, X_test, y_train, y_test = train_test_split(
                        X, y, test_size=0.2, random_state=42
                    )
                    model = LogisticRegression(max_iter=1000)
                    model.fit(X_train, y_train)
                    score = model.score(X_test, y_test)  # accuracy on the 20% hold-out
                    # coef_ has one row per class (a single row for binary
                    # problems); average |coef| over classes instead of only
                    # looking at the first class.
                    importance = pd.DataFrame({
                        "feature": list(X.columns),
                        "importance": np.abs(model.coef_).mean(axis=0),
                    }).sort_values("importance", ascending=True)

                    def draw_importance(ax):
                        sns.barplot(data=importance, x="importance", y="feature", ax=ax)
                        ax.set_title("Feature Importance in Career Prediction")

                    output_paths.append(
                        _render_plot("feature_importance.png", (10, 6), draw_importance)
                    )
                    return output_paths, f"Model accuracy: {score:.2f}"
                return output_paths, "Insufficient unique careers for classification"
        return output_paths, "Analysis completed successfully"
    except Exception as e:
        return output_paths, f"Error during analysis: {str(e)}"
class ChatHistory:
    """Per-session conversation memory plus the dataset it refers to.

    A plain Python object (it does not inherit from any Gradio component).
    Holds the recent messages, the latest data summary text, and the
    currently loaded DataFrame (``None`` until one is loaded).
    """

    def __init__(self):
        # Start from the same empty state that clear() produces.
        self.clear()

    def add_message(self, role, content):
        """Record a message, keeping only the 10 most recent entries."""
        self.messages.append({"role": role, "content": content})
        if len(self.messages) > 10:
            del self.messages[0]

    def get_context(self):
        """Render the 5 most recent messages as ``role: content`` lines."""
        recent = self.messages[-5:]
        lines = (f"{entry['role']}: {entry['content']}" for entry in recent)
        return "\n".join(lines)

    def set_data_summary(self, summary):
        """Replace the stored data summary."""
        self.data_summary = summary

    def get_full_context(self):
        """Return the data summary followed by the recent chat context."""
        return "Data Summary:\n{}\n\nChat History:\n{}".format(
            self.data_summary, self.get_context()
        )

    def clear(self):
        """Forget all messages, the data summary, and the cached DataFrame."""
        self.messages = []
        self.data_summary = ""
        self.current_df = None
def analyze_and_visualize(file, message, history, api_key, chat_history):
    """Handle one chat turn: load or reuse the data, summarize, consult the LLM, plot.

    Args:
        file: The uploaded CSV from ``gr.File``, either a filesystem path or a
            tempfile-like object exposing ``.name`` (which one depends on the
            Gradio version / ``type`` setting); falsy when nothing is uploaded.
        message: The user's chat message (``None`` is treated as ``""``).
        history: Chatbot history as a list of ``(user, bot)`` tuples, or
            ``None``.
        api_key: Optional API key; when empty, no LLM call is made.
        chat_history: The session's ``ChatHistory``. Mutated in place: it
            caches the loaded DataFrame, the data summary, and the messages.

    Returns:
        ``(new_history, chat_history)``. *new_history* is *history* plus this
        turn. When the message contains a visualization keyword, it also gets
        one ``(None, (path,))`` entry per generated plot. Errors are reported
        as the bot reply instead of being raised.
    """
    history = list(history or [])
    message = message or ""
    if not file and chat_history.current_df is None:
        return history + [(message, "Please upload a CSV file first.")], chat_history
    try:
        # Load data if a file is present; otherwise reuse the session's DataFrame.
        if file:
            csv_path = file if isinstance(file, (str, os.PathLike)) else file.name
            df = pd.read_csv(csv_path)
            chat_history.current_df = df
        else:
            df = chat_history.current_df

        numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
        cat_cols = df.select_dtypes(exclude=["number"]).columns.tolist()
        if numeric_cols:
            stats = df[numeric_cols].describe().round(2).to_string()
        else:
            # DataFrame.describe() raises ValueError on a frame with no columns.
            stats = "No numeric columns to describe."
        # str() guards against non-string column labels in join().
        numeric_list = ", ".join(map(str, numeric_cols)) if numeric_cols else "None"
        cat_list = ", ".join(map(str, cat_cols)) if cat_cols else "None"
        summary = (
            "📊 Data Summary:\n"
            f"• Rows: {df.shape[0]}, Columns: {df.shape[1]}\n"
            f"• Numeric columns: {numeric_list}\n"
            f"• Categorical columns: {cat_list}\n"
            f"\nDescriptive Statistics:\n{stats}\n"
        )

        # Only a (re)loaded file replaces the stored summary.
        if file:
            chat_history.set_data_summary(summary)
        chat_history.add_message("user", message)

        # LLM insights only when an API key was supplied.
        if api_key:
            gpt_response = call_gpt4o_mini(
                api_key,
                chat_history.get_full_context(),
                message,
            )
            chat_history.add_message("assistant", gpt_response)
            response_text = f"{summary}\n\n🤖 AI Insights:\n{gpt_response}"
        else:
            response_text = f"{summary}\n\nNote: Add an API key for AI-powered insights."

        # Generate plots only when the message asks for them.
        viz_triggers = ("visualize", "plot", "show", "graph", "analyze", "distribution")
        lowered = message.lower()
        if any(trigger in lowered for trigger in viz_triggers):
            analysis_paths, analysis_info = extended_analysis(df)
            if analysis_info:
                response_text += f"\n\n📈 Analysis Results:\n{analysis_info}"
            chat_content = [(message, response_text)]
            chat_content.extend((None, (path,)) for path in analysis_paths)
            return history + chat_content, chat_history

        return history + [(message, response_text)], chat_history
    except Exception as e:
        error_msg = f"Error processing request: {str(e)}"
        chat_history.add_message("system", error_msg)
        return history + [(message, error_msg)], chat_history
def create_demo():
    """Build and return the Gradio Blocks UI (queue enabled).

    Session data lives in a ``gr.State`` seeded with a fresh ``ChatHistory``.
    Event listeners only accept Gradio components as inputs/outputs, so the
    plain ``ChatHistory`` object cannot be wired in directly. ``gr.State`` also
    copies its initial value per session, so users do not share one history
    and cached DataFrame.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        chat_history = gr.State(ChatHistory())
        gr.Markdown("# Interactive Data Analysis Assistant")
        gr.Markdown("""
**Instructions**:
1. (Optional) Enter your API key for AI-powered insights
2. Upload your CSV file
3. Ask questions about your data or request visualizations
4. Type keywords like 'visualize', 'plot', 'analyze' to generate charts
""")
        with gr.Row():
            with gr.Column():
                api_key = gr.Textbox(
                    label="API Key (Optional)",
                    placeholder="Enter your API key here",
                    type="password",
                )
                file_input = gr.File(label="Upload CSV", file_types=[".csv"])
                msg = gr.Textbox(
                    label="Ask about your data",
                    placeholder="e.g., 'Show me the distribution of scores' or 'What insights can you find?'",
                )
                with gr.Row():
                    send_btn = gr.Button("Send")
                    clear_btn = gr.Button("Clear Chat")
            with gr.Column():
                chatbot = gr.Chatbot(height=600)

        # Enter in the textbox and the Send button do the same thing; after
        # handling the message, empty the input box.
        for trigger in (msg.submit, send_btn.click):
            trigger(
                fn=analyze_and_visualize,
                inputs=[file_input, msg, chatbot, api_key, chat_history],
                outputs=[chatbot, chat_history],
            ).then(lambda: "", None, [msg])

        clear_btn.click(
            fn=lambda: ([], ChatHistory()),
            outputs=[chatbot, chat_history],
        )
    demo.queue()
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and serve it. The try/finally removes
    # the plot scratch directory however this ends — a normal return from
    # launch(), an exception raised while serving (e.g. KeyboardInterrupt),
    # or a failure while building the UI.
    try:
        demo = create_demo()
        demo.launch()
    finally:
        cleanup_temp_files()