import os
import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Optional
import openai
from dataclasses import dataclass
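# Optional extras for deeper analysis; imported but not used in the current UI flow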
import plotly.express as px
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import statsmodels.api as sm
# System prompt for data analysis
DATA_ANALYSIS_PROMPT = """
<DataScienceExpertFramework version="1.0">
<Identity>
<Description>You are an expert data scientist who combines technical precision with clear insights.</Description>
</Identity>
<CoreCapabilities>
<Analysis>
<Capability>Statistical analysis and hypothesis testing</Capability>
<Capability>Pattern recognition and insights</Capability>
<Capability>Data visualization recommendations</Capability>
</Analysis>
</CoreCapabilities>
<AnalysisApproach>
<Step>Assess data quality and structure</Step>
<Step>Identify key patterns and relationships</Step>
<Step>Perform statistical analysis</Step>
<Step>Generate visualizations</Step>
<Step>Provide actionable insights</Step>
</AnalysisApproach>
</DataScienceExpertFramework>
"""
def format_stats_results(results: Dict) -> str:
"""Format statistical results for display"""
formatted = []
for test_name, result in results.items():
if "normality" in test_name:
formatted.append(f"- {test_name}: {'Normal' if result['is_normal'] else 'Non-normal'} "
f"(p={result['p_value']:.4f})")
elif "correlation" in test_name:
formatted.append(f"- {test_name}: {result['correlation']:.4f} "
f"(p={result['p_value']:.4f})")
return "\n".join(formatted)
def analyze_data(df: pd.DataFrame) -> Dict:
"""Analyze dataframe and return statistics"""
analysis = {
"shape": df.shape,
"dtypes": df.dtypes.to_dict(),
"missing": df.isnull().sum().to_dict(),
"numeric_summary": df.describe().to_dict(),
"correlations": {}
}
# Calculate correlations for numeric columns
numeric_cols = df.select_dtypes(include=[np.number]).columns
if len(numeric_cols) >= 2:
corr_matrix = df[numeric_cols].corr()
analysis["correlations"] = corr_matrix.to_dict()
return analysis
def create_visualizations(df: pd.DataFrame, save_dir: str = "figures") -> List[str]:
"""Create and save visualizations"""
os.makedirs(save_dir, exist_ok=True)
paths = []
# Correlation heatmap
numeric_cols = df.select_dtypes(include=[np.number]).columns
if len(numeric_cols) >= 2:
plt.figure(figsize=(10, 8))
sns.heatmap(df[numeric_cols].corr(), annot=True, cmap='coolwarm')
plt.title("Correlation Heatmap")
path = os.path.join(save_dir, "correlation_heatmap.png")
plt.savefig(path)
plt.close()
paths.append(path)
# Distribution plots for numeric columns
for col in numeric_cols[:5]: # Limit to first 5 columns
plt.figure(figsize=(10, 6))
sns.histplot(df[col], kde=True)
plt.title(f"Distribution of {col}")
path = os.path.join(save_dir, f"dist_{col}.png")
plt.savefig(path)
plt.close()
paths.append(path)
return paths
def chat_with_data_scientist(message: str, history: List, api_key: str, df: Optional[pd.DataFrame] = None, system_prompt: str = DATA_ANALYSIS_PROMPT) -> List:
"""Chat with GPT-4o-mini about data analysis"""
    if not api_key:
        return history + [
            (message, "Please provide an OpenAI API key to continue.")
        ]
    if df is None:
        return history + [
            (message, "Please upload a CSV file to analyze.")
        ]
try:
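        # Uses the openai>=1.0 client interface (openai.OpenAI)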
client = openai.OpenAI(api_key=api_key)
# Create analysis summary
analysis = analyze_data(df)
analysis_text = f"""
Dataset Shape: {analysis['shape']}
Missing Values: {sum(analysis['missing'].values())}
Numeric Columns: {list(analysis['numeric_summary'].keys())}
"""
messages = [
{"role": "system", "content": DATA_ANALYSIS_PROMPT},
{"role": "system", "content": f"Analysis Context:\n{analysis_text}"},
{"role": "user", "content": message}
]
response = client.chat.completions.create(
model="gpt-4o-mini",
messages=messages,
max_tokens=500
)
return history + [
(message, response.choices[0].message.content)
]
except Exception as e:
return history + [
(message, f"Error: {str(e)}")
]
def create_demo():
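    """Build the Gradio Blocks UI: upload a CSV, set a system prompt, and chat about the data."""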
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# πŸ”¬ Data Science Expert")
with gr.Row():
with gr.Column():
api_key = gr.Textbox(
label="GPT-4o-mini API Key",
placeholder="sk-...",
type="password"
)
file_input = gr.File(
label="Upload CSV file",
file_types=[".csv"]
)
system_prompt = gr.Textbox(
label="System Prompt",
value=DATA_ANALYSIS_PROMPT,
lines=5
)
with gr.Column():
chat = gr.Chatbot(label="Analysis Chat")
msg = gr.Textbox(
label="Ask about your data",
placeholder="What insights can you find in this dataset?"
)
clear = gr.Button("Clear")
# Store DataFrame in state
df_state = gr.State(None)
        def process_file(file):
            if file is None:
                return None
            # gr.File may pass a tempfile-like object (.name attr) or a plain filepath string
            path = file if isinstance(file, str) else file.name
            return pd.read_csv(path)
file_input.change(
process_file,
inputs=[file_input],
outputs=[df_state]
)
        msg.submit(
            chat_with_data_scientist,
            inputs=[msg, chat, api_key, df_state, system_prompt],
            outputs=[chat]
        )
clear.click(lambda: None, None, chat)
return demo
demo = create_demo()
if __name__ == "__main__":
demo.launch()
else:
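    # When imported rather than run directly, launch without the auto-generated API docs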
demo.launch(show_api=False)