|
import os |
|
import requests |
|
import gradio as gr |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import numpy as np |
|
from sklearn.model_selection import train_test_split |
|
from sklearn.linear_model import LogisticRegression |
|
from sklearn.preprocessing import LabelEncoder |
|
import tempfile |
|
import shutil |
|
|
|
|
|
# Scratch directory for generated chart images. Holding it through a
# TemporaryDirectory object means the directory is also removed when the
# interpreter exits, even if this module is imported instead of being run as
# a script (the `if __name__ == "__main__"` cleanup does not run in that case).
_TEMP_DIR_HANDLE = tempfile.TemporaryDirectory()
TEMP_DIR = _TEMP_DIR_HANDLE.name


def cleanup_temp_files():
    """Delete the scratch directory and everything written into it.

    Safe to call more than once: after the directory has been removed,
    further calls do nothing instead of reporting a "file not found" error.
    Other failures are printed rather than raised, because this is called
    from a ``finally`` clause where raising would replace the original error.
    """
    try:
        _TEMP_DIR_HANDLE.cleanup()
    except Exception as e:
        print(f"Error cleaning up temporary files: {e}")


def get_temp_path(filename):
    """Return the path *filename* would have inside the scratch directory.

    Only builds the path; the file itself is not created.
    """
    return os.path.join(TEMP_DIR, filename)
|
|
|
def call_gpt4o_mini(api_key, context, user_prompt, model=None, timeout=30):
    """Ask an OpenAI-compatible chat-completions endpoint about the data.

    Args:
        api_key: Bearer token; must start with ``sk-`` once surrounding
            whitespace is removed.
        context: Data summary / chat transcript sent ahead of the question.
        user_prompt: The user's question.
        model: Model name to request. Defaults to the ``GPT4O_MINI_MODEL``
            environment variable, or ``"gpt-4o-mini"`` when that is unset
            (previously the request hard-coded ``"gpt-4"`` despite this
            function's name).
        timeout: Seconds to wait for the HTTP response. LLM completions
            routinely take longer than the old fixed 10 s.

    Returns:
        The assistant's reply text, or a human-readable error string. HTTP
        and response-parsing failures are returned as text, not raised, so
        the chat UI can show them directly.
    """
    api_key = (api_key or "").strip()
    if not api_key.startswith('sk-'):
        return "Please provide a valid API key (should start with 'sk-')"

    try:
        url = os.getenv('GPT4O_MINI_API_URL', 'https://api.openai.com/v1/chat/completions')
        model = model or os.getenv('GPT4O_MINI_MODEL', 'gpt-4o-mini')

        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

        messages = [
            {"role": "system", "content": "You are a data analysis assistant. Analyze the provided data and context."},
            {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {user_prompt}"},
        ]

        payload = {
            "model": model,
            "messages": messages,
            "max_tokens": 500,
            "temperature": 0.7,
        }

        response = requests.post(url, json=payload, headers=headers, timeout=timeout)
        # Give a friendlier message than raise_for_status() for a bad key.
        if response.status_code == 401:
            return "Invalid API key. Please check your credentials."
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
    except requests.exceptions.RequestException as e:
        return f"API Error: {str(e)}"
    except Exception as e:
        # e.g. a response body without the expected "choices" structure.
        return f"Error: {str(e)}"
|
|
|
def _save_current_figure(filename):
    """Lay out the active matplotlib figure, save it as a PNG in the temp dir, return the path."""
    plt.tight_layout()
    path = get_temp_path(filename)
    plt.savefig(path, dpi=300, bbox_inches='tight')
    return path


def _plot_correlation_heatmap(df, numeric_cols):
    """Save a lower-triangle correlation heatmap of *numeric_cols*; return its path."""
    fig = plt.figure(figsize=(12, 8))
    try:
        corr = df[numeric_cols].corr()
        # Hide the upper triangle (including the diagonal of 1.0s); it mirrors the lower one.
        mask = np.triu(np.ones_like(corr, dtype=bool))
        sns.heatmap(corr, mask=mask, annot=True, cmap="coolwarm", fmt=".2f",
                    square=True, linewidths=.5)
        plt.title("Correlation Heatmap of Numeric Features")
        return _save_current_figure("heatmap.png")
    finally:
        # Close even when plotting fails so figures don't pile up across requests.
        plt.close(fig)


def _plot_career_distribution(careers):
    """Save a bar chart of how often each value of *careers* occurs; return its path."""
    fig = plt.figure(figsize=(12, 6))
    try:
        counts = careers.value_counts()
        sns.barplot(x=counts.index, y=counts.values)
        plt.title("Distribution of Careers")
        plt.xticks(rotation=45, ha='right')
        plt.xlabel("Career")
        plt.ylabel("Count")
        return _save_current_figure("career_distribution.png")
    finally:
        plt.close(fig)


def _plot_score_distribution(df, score_cols):
    """Save a box plot of the numeric *score_cols*; return its path."""
    fig = plt.figure(figsize=(15, 8))
    try:
        df[score_cols].boxplot()
        plt.title("Distribution of Scores and Aptitudes")
        plt.xticks(rotation=45, ha='right')
        plt.ylabel("Score")
        plt.grid(True, alpha=0.3)
        return _save_current_figure("scores_distribution.png")
    finally:
        plt.close(fig)


def _fit_career_model(df, feature_cols):
    """Predict ``Career`` from *feature_cols* with logistic regression.

    Returns ``(importance_plot_path or None, status message)``.
    """
    # Rows without a career label cannot be used for training, and NaN mixed
    # with strings makes LabelEncoder fail.
    labelled = df[df["Career"].notna()]
    # Encode into a local array instead of writing a column back into *df*:
    # the caller may cache and reuse the DataFrame, and a stored encoded
    # column would be picked up as a numeric feature next time (target leak).
    y = LabelEncoder().fit_transform(labelled["Career"].astype(str))
    if len(np.unique(y)) <= 1:
        return None, "Insufficient unique careers for classification"

    X = labelled[feature_cols]
    # Mean-impute gaps; a column that is entirely NaN has no mean, so fall back to 0.
    X = X.fillna(X.mean()).fillna(0)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    model = LogisticRegression(max_iter=1000)
    model.fit(X_train, y_train)
    score = model.score(X_test, y_test)

    # coef_ has one row per class (a single row for binary problems); average
    # the magnitudes so every class contributes, not just the first one.
    importance = pd.DataFrame({
        'feature': feature_cols,
        'importance': np.abs(model.coef_).mean(axis=0),
    }).sort_values('importance', ascending=True)

    fig = plt.figure(figsize=(10, 6))
    try:
        sns.barplot(data=importance, x='importance', y='feature')
        plt.title("Feature Importance in Career Prediction")
        path = _save_current_figure("feature_importance.png")
    finally:
        plt.close(fig)
    return path, f"Model accuracy: {score:.2f}"


def extended_analysis(df):
    """Render exploratory charts for *df* and, when possible, fit a career model.

    Each chart is produced only when its inputs exist and is saved as a PNG in
    the temp directory:
      * correlation heatmap of the numeric columns (needs at least two),
      * bar chart of ``Career`` value counts,
      * box plot of numeric columns whose name contains "score" or "aptitude",
      * feature-importance chart from a logistic regression that predicts
        ``Career`` from the other numeric columns (needs 2+ distinct careers).

    *df* is not modified.

    Returns:
        ``(paths, message)``: the image paths written so far, plus a status
        string (model accuracy, why the model was skipped, success, or the
        error that stopped the analysis).
    """
    output_paths = []
    numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()

    try:
        if len(numeric_cols) > 1:
            output_paths.append(_plot_correlation_heatmap(df, numeric_cols))

        has_career = "Career" in df.columns
        if has_career:
            output_paths.append(_plot_career_distribution(df["Career"]))

        # Only numeric columns can be box-plotted; str() guards non-string labels.
        score_cols = [
            col for col in numeric_cols
            if 'score' in str(col).lower() or 'aptitude' in str(col).lower()
        ]
        if score_cols:
            output_paths.append(_plot_score_distribution(df, score_cols))

        # A numeric Career column must not be used to predict itself.
        feature_cols = [col for col in numeric_cols if col != "Career"]
        if has_career and feature_cols:
            importance_path, message = _fit_career_model(df, feature_cols)
            if importance_path is not None:
                output_paths.append(importance_path)
            return output_paths, message

        return output_paths, "Analysis completed successfully"
    except Exception as e:
        return output_paths, f"Error during analysis: {str(e)}"
|
|
|
class ChatHistory:
    """Rolling conversation log plus the dataset the conversation is about.

    A plain Python object (not a Gradio component). At most the ten most
    recent messages are retained, and the last five are used as prompt context.
    """

    def __init__(self):
        # Start from the same empty state that clear() restores.
        self.clear()

    def add_message(self, role, content):
        """Append a ``{"role": ..., "content": ...}`` entry, dropping the oldest beyond ten."""
        entry = {"role": role, "content": content}
        self.messages.append(entry)
        if len(self.messages) > 10:
            del self.messages[0]

    def get_context(self):
        """Render the five most recent messages as ``role: content`` lines."""
        recent = self.messages[-5:]
        return "\n".join(f"{item['role']}: {item['content']}" for item in recent)

    def set_data_summary(self, summary):
        """Remember the text summary of the currently loaded dataset."""
        self.data_summary = summary

    def get_full_context(self):
        """Combine the stored data summary with the recent chat transcript."""
        transcript = self.get_context()
        return f"Data Summary:\n{self.data_summary}\n\nChat History:\n{transcript}"

    def clear(self):
        """Reset to empty: no messages, no summary, no DataFrame."""
        self.messages = []
        self.data_summary = ""
        self.current_df = None
|
|
|
def _build_data_summary(df):
    """Return a plain-text overview of *df*: shape, column kinds, describe() stats."""
    numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
    cat_cols = df.select_dtypes(exclude=["number"]).columns.tolist()

    if numeric_cols:
        stats = df[numeric_cols].describe().round(2).to_string()
    else:
        # describe() raises on a frame with no columns (e.g. an all-text CSV).
        stats = "No numeric columns to describe."

    # str() guards against non-string column labels in join().
    numeric_names = ', '.join(map(str, numeric_cols)) if numeric_cols else 'None'
    cat_names = ', '.join(map(str, cat_cols)) if cat_cols else 'None'
    return (
        "📊 Data Summary:\n"
        f"• Rows: {df.shape[0]}, Columns: {df.shape[1]}\n"
        f"• Numeric columns: {numeric_names}\n"
        f"• Categorical columns: {cat_names}\n"
        f"\nDescriptive Statistics:\n{stats}\n"
    )


def analyze_and_visualize(file, message, history, api_key, chat_history):
    """Answer one chat turn about the uploaded CSV.

    Args:
        file: The uploaded CSV as delivered by the file input (an object whose
            ``.name`` is the path, or a path string), or ``None`` to reuse the
            DataFrame cached on *chat_history*.
        message: The user's text.
        history: Chatbot transcript as a list of ``(user, bot)`` tuples.
        api_key: Optional key; when set, the reply includes AI insights.
        chat_history: ``ChatHistory`` carrying the cached DataFrame, data
            summary and recent messages. A fresh one is created if ``None``.

    Returns:
        ``(new_history, chat_history)``. When the message contains a
        visualization keyword, each generated chart is appended as its own
        ``(None, (path,))`` entry. Errors are reported in the transcript,
        not raised.
    """
    message = message or ""
    history = history or []
    if chat_history is None:
        chat_history = ChatHistory()

    if not file and chat_history.current_df is None:
        return history + [(message, "Please upload a CSV file first.")], chat_history

    try:
        if file:
            # Accept both a file wrapper exposing its path as `.name` and a plain path.
            csv_path = getattr(file, "name", file)
            df = pd.read_csv(csv_path)
            chat_history.current_df = df
        else:
            df = chat_history.current_df

        summary = _build_data_summary(df)
        if file:
            chat_history.set_data_summary(summary)

        # Snapshot the context before recording this turn so the question
        # isn't sent to the model twice (in the history and as the prompt).
        context = chat_history.get_full_context()
        chat_history.add_message("user", message)

        if api_key:
            gpt_response = call_gpt4o_mini(api_key, context, message)
            chat_history.add_message("assistant", gpt_response)
            response_text = f"{summary}\n\n🤖 AI Insights:\n{gpt_response}"
        else:
            response_text = f"{summary}\n\nNote: Add an API key for AI-powered insights."

        viz_triggers = ["visualize", "plot", "show", "graph", "analyze", "distribution"]
        lowered = message.lower()
        if any(trigger in lowered for trigger in viz_triggers):
            analysis_paths, analysis_info = extended_analysis(df)
            if analysis_info:
                response_text += f"\n\n📈 Analysis Results:\n{analysis_info}"

            chat_content = [(message, response_text)]
            # Each image becomes its own bot-only message in the transcript.
            chat_content.extend((None, (path,)) for path in analysis_paths)
            return history + chat_content, chat_history

        return history + [(message, response_text)], chat_history

    except Exception as e:
        error_msg = f"Error processing request: {str(e)}"
        chat_history.add_message("system", error_msg)
        return history + [(message, error_msg)], chat_history
|
|
|
def create_demo():
    """Build the Gradio UI for the data-analysis chat assistant.

    Layout: API key, CSV upload, question box and buttons on the left; the
    chat transcript (text and chart images) on the right. Returns the Blocks
    app with queuing enabled; the caller is responsible for ``launch()``.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Interactive Data Analysis Assistant")
        gr.Markdown(
            "**Instructions**:\n"
            "1. (Optional) Enter your API key for AI-powered insights\n"
            "2. Upload your CSV file\n"
            "3. Ask questions about your data or request visualizations\n"
            "4. Type keywords like 'visualize', 'plot', 'analyze' to generate charts\n"
        )

        # Event inputs/outputs must be Gradio components, so the ChatHistory
        # object is held in gr.State (previously the raw object was passed,
        # which Gradio cannot wire up). gr.State also keeps a separate value
        # per user session instead of one object shared by everyone.
        chat_state = gr.State(ChatHistory())

        with gr.Row():
            with gr.Column():
                api_key = gr.Textbox(
                    label="API Key (Optional)",
                    placeholder="Enter your API key here",
                    type="password",
                )
                file_input = gr.File(label="Upload CSV", file_types=[".csv"])
                msg = gr.Textbox(
                    label="Ask about your data",
                    placeholder="e.g., 'Show me the distribution of scores' or 'What insights can you find?'",
                )
                with gr.Row():
                    send_btn = gr.Button("Send")
                    clear_btn = gr.Button("Clear Chat")

            with gr.Column():
                chatbot = gr.Chatbot(height=600)

        # Pressing Enter and clicking Send do the same thing: run the
        # analysis, then clear the question box.
        for trigger in (msg.submit, send_btn.click):
            trigger(
                fn=analyze_and_visualize,
                inputs=[file_input, msg, chatbot, api_key, chat_state],
                outputs=[chatbot, chat_state],
            ).then(lambda: "", None, [msg])

        clear_btn.click(
            fn=lambda: ([], ChatHistory()),
            outputs=[chatbot, chat_state],
        )

    demo.queue()
    return demo
|
|
|
if __name__ == "__main__":
    # Build and serve the UI; the scratch directory is removed however
    # building or launching ends (normal return or an exception).
    try:
        create_demo().launch()
    finally:
        cleanup_temp_files()