import os
import io
import json
import base64

import gradio as gr
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
import matplotlib.pyplot as plt

from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass

from transformers import Tool, ReactCodeAgent, HfApiEngine
import openai

from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
import statsmodels.api as sm


@dataclass
class AgentConfig:
    """Configuration for the data science agent."""
    system_prompt: str = """
<DataScienceExpertFramework version="2.0">
  <Identity>
    <Role>Expert Data Scientist and ML Engineer</Role>
    <Expertise>
      <Area>Statistical Analysis</Area>
      <Area>Machine Learning</Area>
      <Area>Data Visualization</Area>
      <Area>Feature Engineering</Area>
      <Area>Time Series Analysis</Area>
    </Expertise>
  </Identity>
  <Capabilities>
    <DataProcessing>
      <Task>Data Cleaning</Task>
      <Task>Feature Engineering</Task>
      <Task>Preprocessing</Task>
    </DataProcessing>
    <Analysis>
      <Task>Statistical Testing</Task>
      <Task>Pattern Recognition</Task>
      <Task>Correlation Analysis</Task>
    </Analysis>
    <MachineLearning>
      <Task>Model Selection</Task>
      <Task>Training</Task>
      <Task>Evaluation</Task>
    </MachineLearning>
    <Visualization>
      <Task>EDA Plots</Task>
      <Task>Statistical Plots</Task>
      <Task>Model Performance Plots</Task>
    </Visualization>
  </Capabilities>
  <OutputFormat>
    <Format>Clear Explanations</Format>
    <Format>Statistical Evidence</Format>
    <Format>Visual Support</Format>
    <Format>Actionable Insights</Format>
  </OutputFormat>
</DataScienceExpertFramework>
"""
    max_iterations: int = 10
    temperature: float = 0.7
    model_name: str = "gpt-4o-mini"
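
# Illustrative only: AgentConfig is a plain dataclass, so any default can be overridden at
# construction time (the values below are hypothetical, not recommendations):
#
#   config = AgentConfig(max_iterations=5, temperature=0.2)
#
# As wired below, the agent itself only consumes max_iterations; temperature and model_name
# are kept for an OpenAI-style engine you might swap in for HfApiEngine.
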
@dataclass
class AnalysisState:
    """Maintains state for an ongoing analysis session."""
    df: Optional[pd.DataFrame] = None
    current_analysis: Optional[Dict] = None
    visualizations: Optional[List[Dict]] = None
    model_results: Optional[Dict] = None
    error_log: Optional[List[str]] = None

    def clear(self):
        """Reset all state, keeping an empty error log."""
        self.df = None
        self.current_analysis = None
        self.visualizations = None
        self.model_results = None
        self.error_log = []

    def log_error(self, error: str):
        """Append an error message, initializing the log on first use."""
        if self.error_log is None:
            self.error_log = []
        self.error_log.append(error)


def process_uploaded_file(file) -> Tuple[Optional[pd.DataFrame], Dict]:
    """Process an uploaded file and return the DataFrame with summary info."""
    try:
        if file.name.endswith('.csv'):
            df = pd.read_csv(file.name)
        elif file.name.endswith('.xlsx'):
            df = pd.read_excel(file.name)
        elif file.name.endswith('.json'):
            df = pd.read_json(file.name)
        else:
            return None, {"error": "Unsupported file format"}

        info = {
            "shape": df.shape,
            "columns": list(df.columns),
            # Stringify dtypes so the summary stays JSON-serializable in the UI.
            "dtypes": df.dtypes.astype(str).to_dict(),
            "missing_values": df.isnull().sum().to_dict(),
            "numeric_columns": list(df.select_dtypes(include=[np.number]).columns),
            "categorical_columns": list(df.select_dtypes(exclude=[np.number]).columns)
        }

        return df, info
    except Exception as e:
        return None, {"error": str(e)}
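
# Minimal usage sketch outside Gradio. The path is hypothetical; any object exposing a
# ``.name`` attribute that points at a CSV/XLSX/JSON file works, which is what gr.File
# passes to callbacks:
#
#   class _Upload:                      # stand-in for the object Gradio hands to callbacks
#       name = "data/sales.csv"
#
#   df, info = process_uploaded_file(_Upload())
#   if df is None:
#       print(info["error"])
#   else:
#       print(info["shape"], info["numeric_columns"])
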
def create_visualization(data: pd.DataFrame, viz_type: str, params: Dict) -> Optional[Dict]:
    """Create a Plotly visualization based on type and parameters."""
    try:
        if viz_type == "scatter":
            fig = px.scatter(
                data,
                x=params["x"],
                y=params["y"],
                color=params.get("color"),
                title=params.get("title", "Scatter Plot")
            )
        elif viz_type == "histogram":
            fig = px.histogram(
                data,
                x=params["x"],
                nbins=params.get("nbins", 30),
                title=params.get("title", "Distribution")
            )
        elif viz_type == "line":
            fig = px.line(
                data,
                x=params["x"],
                y=params["y"],
                title=params.get("title", "Line Plot")
            )
        elif viz_type == "heatmap":
            numeric_cols = data.select_dtypes(include=[np.number]).columns
            corr = data[numeric_cols].corr()
            fig = px.imshow(
                corr,
                labels=dict(color="Correlation"),
                title=params.get("title", "Correlation Heatmap")
            )
        else:
            return None

        return fig.to_dict()
    except Exception as e:
        return {"error": str(e)}
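
# Illustrative parameter shapes for each supported viz_type (column names are hypothetical
# and must exist in the DataFrame you pass in):
#
#   create_visualization(df, "scatter",   {"x": "age", "y": "income", "color": "segment"})
#   create_visualization(df, "histogram", {"x": "income", "nbins": 20})
#   create_visualization(df, "line",      {"x": "date", "y": "revenue", "title": "Revenue"})
#   create_visualization(df, "heatmap",   {})   # correlation of all numeric columns
#
# The return value is ``fig.to_dict()`` (or ``{"error": ...}`` on failure); rebuild a figure
# with ``go.Figure(result)`` if you want to show it in a gr.Plot component.
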
class ChatInterface:
    """Manages the chat interface and message handling."""

    def __init__(self, agent_config: AgentConfig):
        self.config = agent_config
        self.history = []
        self.agent = self._create_agent()

    def _create_agent(self) -> ReactCodeAgent:
        """Initialize the ReAct code agent with its tools and LLM engine."""
        tools = self._get_tools()
        llm_engine = HfApiEngine()
        return ReactCodeAgent(
            tools=tools,
            llm_engine=llm_engine,
            max_iterations=self.config.max_iterations
        )

    def _get_tools(self) -> List[Tool]:
        """Get the list of available tools (defined in the project-local tools module)."""
        from tools import (
            DataPreprocessingTool,
            StatisticalAnalysisTool,
            VisualizationTool,
            MLModelTool,
            TimeSeriesAnalysisTool
        )

        return [
            DataPreprocessingTool(),
            StatisticalAnalysisTool(),
            VisualizationTool(),
            MLModelTool(),
            TimeSeriesAnalysisTool()
        ]

    def process_message(self, message: str, analysis_state: AnalysisState) -> Tuple[List, Any]:
        """Process a user message and return the updated chat history and agent response."""
        try:
            if analysis_state.df is None:
                return self.history + [(message, "Please upload a data file first.")], None

            context = {
                "data_info": {
                    "shape": analysis_state.df.shape,
                    "columns": list(analysis_state.df.columns),
                    "dtypes": analysis_state.df.dtypes.to_dict()
                },
                "current_analysis": analysis_state.current_analysis,
                "available_tools": [tool.name for tool in self._get_tools()]
            }

            # default=str keeps numpy dtypes and other non-JSON types from breaking the dump.
            response = self.agent.run(
                f"Context: {json.dumps(context, default=str)}\nUser request: {message}"
            )

            self.history.append((message, response))
            return self.history, response
        except Exception as e:
            error_msg = f"Error processing message: {str(e)}"
            analysis_state.log_error(error_msg)
            return self.history + [(message, error_msg)], None
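
# Headless usage sketch, independent of the Gradio UI. It assumes the project-local
# ``tools`` module is importable and that a Hugging Face token is available to HfApiEngine
# (e.g. via the HF_TOKEN environment variable); the CSV path is hypothetical:
#
#   state = AnalysisState()
#   state.df, state.current_analysis = process_uploaded_file(open("data/sales.csv"))
#   chat = ChatInterface(AgentConfig())
#   history, answer = chat.process_message("Which columns correlate with the target?", state)
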
def create_demo():
    config = AgentConfig()
    analysis_state = AnalysisState()
    chat_interface = ChatInterface(config)

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🔬 Advanced Data Science Agent")

        with gr.Row():
            with gr.Column(scale=1):
                # Collected for future use; the demo's agent currently runs on HfApiEngine.
                api_key = gr.Textbox(
                    label="API Key (GPT-4o-mini)",
                    type="password",
                    placeholder="sk-..."
                )

                file_input = gr.File(
                    label="Upload Data",
                    file_types=[".csv", ".xlsx", ".json"]
                )

                with gr.Accordion("Analysis Settings", open=False):
                    analysis_type = gr.Radio(
                        choices=[
                            "Exploratory Analysis",
                            "Statistical Analysis",
                            "Machine Learning",
                            "Time Series Analysis"
                        ],
                        label="Analysis Type",
                        value="Exploratory Analysis"
                    )

                    visualization_type = gr.Dropdown(
                        choices=[
                            "Automatic",
                            "Scatter Plots",
                            "Distributions",
                            "Correlations",
                            "Time Series"
                        ],
                        label="Visualization Type",
                        value="Automatic"
                    )

                    model_params = gr.JSON(
                        label="Model Parameters",
                        value={
                            "test_size": 0.2,
                            "n_estimators": 100,
                            "handle_outliers": True
                        }
                    )

                with gr.Accordion("System Settings", open=False):
                    system_prompt = gr.Textbox(
                        label="System Prompt",
                        value=config.system_prompt,
                        lines=10
                    )

                    max_iterations = gr.Slider(
                        minimum=1,
                        maximum=20,
                        value=config.max_iterations,
                        step=1,
                        label="Max Iterations"
                    )

            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Analysis Chat",
                    height=400
                )

                with gr.Row():
                    text_input = gr.Textbox(
                        label="Ask about your data",
                        placeholder="What would you like to analyze?",
                        lines=2
                    )
                    submit_btn = gr.Button("Analyze", variant="primary")

                with gr.Row():
                    clear_btn = gr.Button("Clear Chat")
                    example_btn = gr.Button("Load Example")

                with gr.Accordion("Visualization", open=True):
                    plot_output = gr.Plot(label="Generated Plots")

                with gr.Accordion("Analysis Results", open=True):
                    results_json = gr.JSON(label="Detailed Results")

                with gr.Accordion("Error Log", open=False):
                    error_output = gr.Textbox(label="Errors", lines=3)

        def handle_file_upload(file):
            df, info = process_uploaded_file(file)
            if df is not None:
                analysis_state.df = df
                analysis_state.current_analysis = info
                return info, None
            # Surface the actual parsing error instead of a generic message.
            return info, info.get("error", "Failed to load file")

        def handle_analysis(message, chat_history):
            history, response = chat_interface.process_message(message, analysis_state)
            return history

        def handle_clear():
            analysis_state.clear()
            chat_interface.history = []
            return None, None, None, None, None

        def load_example_data():
            import sklearn.datasets

            data = sklearn.datasets.load_diabetes()
            df = pd.DataFrame(data.data, columns=data.feature_names)
            df['target'] = data.target

            analysis_state.df = df
            analysis_state.current_analysis = {
                "shape": df.shape,
                "columns": list(df.columns),
                "dtypes": df.dtypes.astype(str).to_dict()
            }

            return analysis_state.current_analysis, None

        file_input.change(
            handle_file_upload,
            inputs=[file_input],
            outputs=[results_json, error_output]
        )

        submit_btn.click(
            handle_analysis,
            inputs=[text_input, chatbot],
            outputs=[chatbot]
        )

        text_input.submit(
            handle_analysis,
            inputs=[text_input, chatbot],
            outputs=[chatbot]
        )

        clear_btn.click(
            handle_clear,
            outputs=[chatbot, plot_output, results_json, error_output, file_input]
        )

        example_btn.click(
            load_example_data,
            outputs=[results_json, error_output]
        )

    return demo


if __name__ == "__main__":
    # Run directly: launch with a public share link.
    demo = create_demo()
    demo.launch(share=True)
else:
    # Imported (e.g. by a hosting runner): launch without exposing the API page.
    demo = create_demo()
    demo.launch(show_api=False)