"""Gradio front end for a tool-using data-science agent.

Users upload a CSV/XLSX/JSON file (or load the sklearn diabetes example) and
then chat with a ``transformers`` ``ReactCodeAgent``; each request is sent to
the agent together with a JSON summary of the loaded DataFrame.
"""

import base64
import io
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import openai
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
import statsmodels.api as sm
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from transformers import HfApiEngine, ReactCodeAgent, Tool


# Configuration class for agent settings
@dataclass
class AgentConfig:
    """Configuration for the data science agent.

    Only ``system_prompt`` (as the initial value of the System Prompt textbox)
    and ``max_iterations`` are read in this file; ``temperature`` and
    ``model_name`` are currently unused.
    """

    system_prompt: str = """ Expert Data Scientist and ML Engineer Statistical Analysis Machine Learning Data Visualization Feature Engineering Time Series Analysis Data Cleaning Feature Engineering Preprocessing Statistical Testing Pattern Recognition Correlation Analysis Model Selection Training Evaluation EDA Plots Statistical Plots Model Performance Plots Clear Explanations Statistical Evidence Visual Support Actionable Insights """
    max_iterations: int = 10
    temperature: float = 0.7
    model_name: str = "gpt-4o-mini"


# Data Analysis State class
@dataclass
class AnalysisState:
    """Maintains state for the analysis currently loaded in the UI."""

    df: Optional[pd.DataFrame] = None
    # Summary dict built when data is loaded (shape, columns, dtypes, ...).
    current_analysis: Optional[Dict] = None
    visualizations: Optional[List[Dict]] = None
    model_results: Optional[Dict] = None
    error_log: Optional[List[str]] = None

    def clear(self) -> None:
        """Drop loaded data and results, and reset the error log to empty."""
        self.df = None
        self.current_analysis = None
        self.visualizations = None
        self.model_results = None
        self.error_log = []

    def log_error(self, error: str) -> None:
        """Append *error* to the error log, creating the log on first use."""
        if self.error_log is None:
            self.error_log = []
        self.error_log.append(error)


def _to_jsonable(obj: Any) -> Any:
    """Recursively convert pandas/numpy values into JSON-serializable ones.

    Dicts get string keys, lists/tuples become lists, numpy scalars become
    their Python equivalents, and anything else that is not a JSON primitive
    (e.g. numpy/pandas dtype objects) is rendered with ``str``.
    """
    if isinstance(obj, dict):
        return {str(key): _to_jsonable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_jsonable(value) for value in obj]
    if isinstance(obj, np.generic):
        return obj.item()
    if obj is None or isinstance(obj, (str, int, float, bool)):
        return obj
    return str(obj)


# Lower-cased file extension -> pandas reader.
_FILE_READERS = {
    ".csv": pd.read_csv,
    ".xlsx": pd.read_excel,
    ".json": pd.read_json,
}


# Helper functions for data processing
def process_uploaded_file(file) -> Tuple[Optional[pd.DataFrame], Dict]:
    """Load an uploaded file into a DataFrame and summarize it.

    Args:
        file: Either a path (``str``/``os.PathLike``, what Gradio 4's
            ``gr.File`` passes by default) or a tempfile-like object whose
            ``.name`` is the path (what older Gradio versions pass).

    Returns:
        ``(df, info)`` on success, where ``info`` has shape, columns, dtypes,
        per-column missing-value counts and numeric/categorical column lists;
        ``(None, {"error": message})`` on any failure.
    """
    if file is None:
        return None, {"error": "No file provided"}
    try:
        # Path objects also have a ``.name`` (just the basename), so check for
        # path-like values before falling back to the file-object attribute.
        if isinstance(file, (str, os.PathLike)):
            path = os.fspath(file)
        else:
            path = file.name

        # Case-insensitive so that e.g. "DATA.CSV" is accepted.
        reader = _FILE_READERS.get(os.path.splitext(str(path))[1].lower())
        if reader is None:
            return None, {"error": "Unsupported file format"}
        df = reader(path)

        info = {
            "shape": df.shape,
            "columns": list(df.columns),
            "dtypes": df.dtypes.to_dict(),
            "missing_values": df.isnull().sum().to_dict(),
            "numeric_columns": list(df.select_dtypes(include=[np.number]).columns),
            "categorical_columns": list(df.select_dtypes(exclude=[np.number]).columns),
        }
        return df, info
    except Exception as e:
        return None, {"error": str(e)}


def create_visualization(data: pd.DataFrame, viz_type: str, params: Dict) -> Optional[Dict]:
    """Create a Plotly figure of the given type and return it as a dict.

    Supported ``viz_type`` values are "scatter", "histogram", "line" and
    "heatmap" (correlation of the numeric columns). Returns ``None`` for an
    unknown type and ``{"error": message}`` if building the figure fails
    (e.g. a required ``params`` key such as "x" is missing).
    """
    try:
        if viz_type == "scatter":
            fig = px.scatter(
                data,
                x=params["x"],
                y=params["y"],
                color=params.get("color"),
                title=params.get("title", "Scatter Plot"),
            )
        elif viz_type == "histogram":
            fig = px.histogram(
                data,
                x=params["x"],
                nbins=params.get("nbins", 30),
                title=params.get("title", "Distribution"),
            )
        elif viz_type == "line":
            fig = px.line(
                data,
                x=params["x"],
                y=params["y"],
                title=params.get("title", "Line Plot"),
            )
        elif viz_type == "heatmap":
            numeric_cols = data.select_dtypes(include=[np.number]).columns
            corr = data[numeric_cols].corr()
            fig = px.imshow(
                corr,
                labels=dict(color="Correlation"),
                title=params.get("title", "Correlation Heatmap"),
            )
        else:
            return None
        return fig.to_dict()
    except Exception as e:
        return {"error": str(e)}


class ChatInterface:
    """Manages the chat history and forwards messages to the agent."""

    def __init__(self, agent_config: AgentConfig):
        self.config = agent_config
        # List of (user_message, agent_reply) pairs shown in the Chatbot.
        self.history = []
        self.agent = self._create_agent()

    def _create_agent(self) -> ReactCodeAgent:
        """Initialize the agent with tools."""
        # Keep the tool instances so process_message can list their names
        # without re-importing and re-instantiating every tool per message.
        self._tools = self._get_tools()
        # NOTE(review): config.model_name, config.temperature and
        # config.system_prompt are not passed here; HfApiEngine() is built
        # with its own defaults.
        llm_engine = HfApiEngine()
        return ReactCodeAgent(
            tools=self._tools,
            llm_engine=llm_engine,
            max_iterations=self.config.max_iterations,
        )

    def _get_tools(self) -> List[Tool]:
        """Instantiate the project's agent tools."""
        # Import tools from our tools.py
        from tools import (
            DataPreprocessingTool,
            MLModelTool,
            StatisticalAnalysisTool,
            TimeSeriesAnalysisTool,
            VisualizationTool,
        )

        return [
            DataPreprocessingTool(),
            StatisticalAnalysisTool(),
            VisualizationTool(),
            MLModelTool(),
            TimeSeriesAnalysisTool(),
        ]

    def process_message(self, message: str, analysis_state: AnalysisState) -> Tuple[List, Any]:
        """Send *message* (plus a data summary) to the agent.

        Returns:
            ``(history, response)``: the chat history to display and the raw
            agent result. If no data is loaded or the agent call fails, the
            history returned includes an explanatory reply (not stored in
            ``self.history``) and ``response`` is ``None``; failures are also
            recorded via ``analysis_state.log_error``.
        """
        try:
            if analysis_state.df is None:
                return self.history + [(message, "Please upload a data file first.")], None

            df = analysis_state.df
            context = {
                "data_info": {
                    "shape": df.shape,
                    "columns": list(df.columns),
                    "dtypes": df.dtypes.to_dict(),
                },
                "current_analysis": analysis_state.current_analysis,
                "available_tools": [tool.name for tool in self._tools],
            }

            # dtypes and numpy counts are not JSON-serializable as-is;
            # json.dumps would raise TypeError on every request.
            response = self.agent.run(
                f"Context: {json.dumps(_to_jsonable(context))}\nUser request: {message}"
            )
            # The Chatbot renders strings; the agent may return other objects.
            self.history.append((message, str(response)))
            return self.history, response
        except Exception as e:
            error_msg = f"Error processing message: {str(e)}"
            analysis_state.log_error(error_msg)
            return self.history + [(message, error_msg)], None


def create_demo() -> gr.Blocks:
    """Build the Gradio Blocks UI and wire up its event handlers."""
    # Initialize configuration and state.
    # NOTE(review): these objects are shared by every session this app serves
    # (no gr.State is used), so concurrent users overwrite each other's data
    # and chat history.
    config = AgentConfig()
    analysis_state = AnalysisState()
    chat_interface = ChatInterface(config)

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🔬 Advanced Data Science Agent")

        with gr.Row():
            with gr.Column(scale=1):
                # NOTE(review): api_key and the Analysis/System settings below
                # are not read by any handler in this file.
                api_key = gr.Textbox(
                    label="API Key (GPT-4o-mini)",
                    type="password",
                    placeholder="sk-...",
                )
                file_input = gr.File(
                    label="Upload Data",
                    file_types=[".csv", ".xlsx", ".json"],
                )

                with gr.Accordion("Analysis Settings", open=False):
                    analysis_type = gr.Radio(
                        choices=[
                            "Exploratory Analysis",
                            "Statistical Analysis",
                            "Machine Learning",
                            "Time Series Analysis",
                        ],
                        label="Analysis Type",
                        value="Exploratory Analysis",
                    )
                    visualization_type = gr.Dropdown(
                        choices=[
                            "Automatic",
                            "Scatter Plots",
                            "Distributions",
                            "Correlations",
                            "Time Series",
                        ],
                        label="Visualization Type",
                        value="Automatic",
                    )
                    model_params = gr.JSON(
                        label="Model Parameters",
                        value={
                            "test_size": 0.2,
                            "n_estimators": 100,
                            "handle_outliers": True,
                        },
                    )

                with gr.Accordion("System Settings", open=False):
                    system_prompt = gr.Textbox(
                        label="System Prompt",
                        value=config.system_prompt,
                        lines=10,
                    )
                    max_iterations = gr.Slider(
                        minimum=1,
                        maximum=20,
                        value=config.max_iterations,
                        step=1,
                        label="Max Iterations",
                    )

            with gr.Column(scale=2):
                # Chat interface
                chatbot = gr.Chatbot(
                    label="Analysis Chat",
                    height=400,
                )
                with gr.Row():
                    text_input = gr.Textbox(
                        label="Ask about your data",
                        placeholder="What would you like to analyze?",
                        lines=2,
                    )
                    submit_btn = gr.Button("Analyze", variant="primary")

                with gr.Row():
                    clear_btn = gr.Button("Clear Chat")
                    example_btn = gr.Button("Load Example")

                # Output displays
                with gr.Accordion("Visualization", open=True):
                    plot_output = gr.Plot(label="Generated Plots")
                with gr.Accordion("Analysis Results", open=True):
                    results_json = gr.JSON(label="Detailed Results")
                with gr.Accordion("Error Log", open=False):
                    error_output = gr.Textbox(label="Errors", lines=3)

        # Event handlers
        def handle_file_upload(file):
            """Load an uploaded file; return (results_json, error_output)."""
            # Clearing the component (including via handle_clear) also fires
            # .change with file=None -- that is not a failed upload.
            if file is None:
                return None, None
            df, info = process_uploaded_file(file)
            if df is None:
                error = info.get("error", "unknown error")
                analysis_state.log_error(error)
                return {"error": error}, f"Failed to load file: {error}"
            analysis_state.df = df
            analysis_state.current_analysis = info
            # gr.JSON needs plain JSON values; dtypes are numpy objects.
            return _to_jsonable(info), None

        def handle_analysis(message, chat_history):
            """Forward a chat message to the agent and return the new history."""
            history, response = chat_interface.process_message(message, analysis_state)
            return history

        def handle_clear():
            """Reset state and blank every output component."""
            analysis_state.clear()
            chat_interface.history = []
            return None, None, None, None, None

        def load_example_data():
            """Load sklearn's diabetes dataset (features + 'target')."""
            import sklearn.datasets

            data = sklearn.datasets.load_diabetes()
            df = pd.DataFrame(data.data, columns=data.feature_names)
            df['target'] = data.target
            analysis_state.df = df
            analysis_state.current_analysis = {
                "shape": df.shape,
                "columns": list(df.columns),
                "dtypes": df.dtypes.to_dict(),
            }
            return _to_jsonable(analysis_state.current_analysis), None

        # Connect event handlers
        file_input.change(
            handle_file_upload,
            inputs=[file_input],
            outputs=[results_json, error_output],
        )
        submit_btn.click(
            handle_analysis,
            inputs=[text_input, chatbot],
            outputs=[chatbot],
        )
        text_input.submit(
            handle_analysis,
            inputs=[text_input, chatbot],
            outputs=[chatbot],
        )
        clear_btn.click(
            handle_clear,
            outputs=[chatbot, plot_output, results_json, error_output, file_input],
        )
        example_btn.click(
            load_example_data,
            outputs=[results_json, error_output],
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=True)
else:
    # NOTE(review): launching on import makes this module unsafe to import
    # from tests or tooling; kept because some host may rely on it.
    demo = create_demo()
    demo.launch(show_api=False)