from transformers import Tool, ReactCodeAgent, HfApiEngine
import gradio as gr
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from typing import Dict, List, Optional
import openai
import seaborn as sns
import matplotlib.pyplot as plt
import io
import base64

# Multiplier applied to the inter-quartile range for the Tukey outlier fences.
_IQR_FENCE = 1.5


def _is_quantile_numeric(series: pd.Series) -> bool:
    """Return True for numeric columns that support quantiles (booleans excluded)."""
    return (
        pd.api.types.is_numeric_dtype(series)
        and not pd.api.types.is_bool_dtype(series)
    )


# Custom Tools for Data Analysis
# NOTE(review): the dict/list arguments are declared with type "any" because
# transformers' Tool validates each declared type against a fixed set of names
# ("string", "integer", "number", "boolean", "any", ...) that does not contain
# "dict" or "list". Confirm against the installed transformers version.
class DataVisualizationTool(Tool):
    """Agent tool that renders a seaborn plot of a DataFrame as a base64 PNG."""

    name = "data_visualizer"
    description = """Creates various types of visualizations from data:
    - Correlation heatmaps
    - Distribution plots
    - Scatter plots
    Returns the plots as base64 encoded images."""
    inputs = {
        "data": {
            "type": "any",
            "description": "DataFrame as dictionary",
        },
        "plot_type": {
            "type": "string",
            "description": "Type of plot to create: 'heatmap', 'distribution', 'scatter'",
        },
        "columns": {
            "type": "any",
            "description": "List of columns to plot",
        },
    }
    output_type = "string"  # base64 encoded image

    def forward(self, data: Dict, plot_type: str, columns: List[str]) -> str:
        """Render the requested plot and return it as a base64-encoded PNG.

        Args:
            data: Mapping accepted by ``pd.DataFrame`` (e.g. ``df.to_dict()``).
            plot_type: One of ``'heatmap'``, ``'distribution'``, ``'scatter'``.
            columns: Columns to plot; ``'scatter'`` uses the first two.

        Returns:
            The PNG image encoded as a base64 UTF-8 string.

        Raises:
            ValueError: For an unknown ``plot_type``, or a scatter plot with
                fewer than two columns (previously a blank image was returned).
        """
        df = pd.DataFrame(data)
        # Use an explicit figure/axes instead of global pyplot state so
        # concurrent Gradio requests cannot draw onto each other's figure.
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            if plot_type == "heatmap":
                # numeric_only avoids a TypeError on text columns in pandas >= 2.
                corr = df[columns].corr(numeric_only=True)
                sns.heatmap(corr, annot=True, cmap='coolwarm', ax=ax)
                ax.set_title("Correlation Heatmap")
            elif plot_type == "distribution":
                for col in columns:
                    sns.histplot(df[col], kde=True, label=col, ax=ax)
                ax.set_title("Distribution Plot")
                ax.legend()
            elif plot_type == "scatter":
                if len(columns) < 2:
                    raise ValueError("Scatter plot requires at least two columns")
                sns.scatterplot(data=df, x=columns[0], y=columns[1], ax=ax)
                ax.set_title(f"Scatter Plot: {columns[0]} vs {columns[1]}")
            else:
                raise ValueError(
                    f"Unknown plot_type {plot_type!r}; "
                    "expected 'heatmap', 'distribution' or 'scatter'"
                )

            # Convert plot to base64
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)
        return base64.b64encode(buf.getvalue()).decode('utf-8')


class DataAnalysisTool(Tool):
    """Agent tool computing summary statistics for selected DataFrame columns."""

    name = "data_analyzer"
    description = """Performs statistical analysis on data:
    - Basic statistics (mean, median, std)
    - Correlation analysis
    - Missing value analysis
    - Outlier detection"""
    inputs = {
        "data": {
            "type": "any",
            "description": "DataFrame as dictionary",
        },
        "analysis_type": {
            "type": "string",
            "description": "Type of analysis: 'basic', 'correlation', 'missing', 'outliers'",
        },
        "columns": {
            "type": "any",
            "description": "List of columns to analyze",
        },
    }
    output_type = "any"

    def forward(self, data: Dict, analysis_type: str, columns: List[str]) -> Dict:
        """Run one analysis over the requested columns and return a dict.

        Columns not present in ``data`` are silently ignored.

        Args:
            data: Mapping accepted by ``pd.DataFrame`` (e.g. ``df.to_dict()``).
            analysis_type: ``'basic'``, ``'correlation'``, ``'missing'`` or
                ``'outliers'``.
            columns: Column names to analyze.

        Returns:
            A dict of results whose keys depend on ``analysis_type``.

        Raises:
            ValueError: For an unknown ``analysis_type`` (previously ``None``
                was returned silently).
        """
        df = pd.DataFrame(data)
        selected_cols = [col for col in columns if col in df.columns]
        selected = df[selected_cols]

        if analysis_type == "basic":
            return {
                "statistics": selected.describe().to_dict(),
                # numeric_only avoids a TypeError on text columns in pandas >= 2.
                "skew": selected.skew(numeric_only=True).to_dict(),
                "kurtosis": selected.kurtosis(numeric_only=True).to_dict(),
            }
        if analysis_type == "correlation":
            numeric_cols = selected.select_dtypes(include=[np.number])
            return {
                "correlation": numeric_cols.corr().to_dict(),
                "covariance": numeric_cols.cov().to_dict(),
            }
        if analysis_type == "missing":
            return {
                "missing_counts": selected.isnull().sum().to_dict(),
                "missing_percentages": (selected.isnull().mean() * 100).to_dict(),
            }
        if analysis_type == "outliers":
            outliers = {}
            for col in selected_cols:
                series = df[col]
                # Covers every numeric width (int32, float32, ...), not just
                # int64/float64; bools are skipped because quantiles are undefined.
                if not _is_quantile_numeric(series):
                    continue
                q1 = series.quantile(0.25)
                q3 = series.quantile(0.75)
                iqr = q3 - q1
                lower = q1 - _IQR_FENCE * iqr
                upper = q3 + _IQR_FENCE * iqr
                outliers[col] = {
                    "outliers_count": int(((series < lower) | (series > upper)).sum()),
                    "lower_bound": lower,
                    "upper_bound": upper,
                }
            return {"outliers": outliers}
        raise ValueError(
            f"Unknown analysis_type {analysis_type!r}; "
            "expected 'basic', 'correlation', 'missing' or 'outliers'"
        )


def create_demo():
    """Build the Gradio Blocks UI wired to a ReactCodeAgent with both tools."""
    # Initialize tools
    viz_tool = DataVisualizationTool()
    analysis_tool = DataAnalysisTool()

    # Create agent with tools
    llm_engine = HfApiEngine()  # Uses default model
    agent = ReactCodeAgent(
        tools=[viz_tool, analysis_tool],
        llm_engine=llm_engine
    )

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🔬 Advanced Data Analysis Agent")

        with gr.Row():
            with gr.Column():
                # NOTE(review): this key is collected but never used — the agent
                # runs on HfApiEngine, not OpenAI. Wire it up or remove it.
                api_key = gr.Textbox(
                    label="OpenAI API Key",
                    type="password",
                    placeholder="sk-..."
                )
                file_input = gr.File(
                    label="Upload CSV",
                    file_types=[".csv"]
                )

                with gr.Accordion("Advanced Settings", open=False):
                    system_prompt = gr.Textbox(
                        label="System Prompt",
                        value="""You are a data science expert. Analyze the data and create visualizations to help understand patterns and insights.""",
                        lines=3
                    )

            with gr.Column():
                chat = gr.Chatbot(label="Analysis Chat")
                msg = gr.Textbox(
                    label="Ask about your data",
                    placeholder="What insights can you find in this dataset?"
                )
                clear = gr.Button("Clear")

        # State for storing the DataFrame
        df_state = gr.State(None)

        def process_file(file):
            """Load the uploaded CSV into a DataFrame (None when cleared)."""
            if file is None:
                return None
            # Gradio 3 passes a tempfile wrapper exposing ``.name``; Gradio 4
            # passes a plain filepath string by default.
            path = getattr(file, "name", file)
            return pd.read_csv(path)

        def process_message(message, chat_history, api_key, df, system_prompt=""):
            """Run the agent on ``message`` and append the exchange to the chat."""
            history = list(chat_history or [])
            if df is None:
                return history + [(message, "Please upload a CSV file first.")]

            try:
                # Convert DataFrame to dict for tools
                data_dict = df.to_dict()
                # Get all columns for potential analysis
                columns = list(df.columns)

                task = (
                    f"Analyze this data: {message}\n"
                    f"Available columns: {columns}\n"
                    "The dataset is provided as the `data` argument (a dict). "
                    "Use the data_analyzer and data_visualizer tools to create insights."
                )
                if system_prompt:
                    task = f"{system_prompt}\n\n{task}"

                # Extra keyword arguments to Agent.run are exposed to the agent's
                # code as variables, which is how the tools can receive `data`.
                # NOTE(review): transformers may also echo these arguments into
                # the prompt, so very large frames can inflate it — confirm.
                response = agent.run(task, data=data_dict)
                # The agent's final answer is not guaranteed to be a string.
                return history + [(message, str(response))]
            except Exception as e:
                # UI boundary: surface the failure in the chat instead of crashing.
                return history + [(message, f"Error: {str(e)}")]

        file_input.change(
            process_file,
            inputs=[file_input],
            outputs=[df_state]
        )

        msg.submit(
            process_message,
            inputs=[msg, chat, api_key, df_state, system_prompt],
            outputs=[chat]
        )

        clear.click(lambda: None, None, chat)

    return demo


# Build the app once at module scope so both branches below can reference it
# (the original else-branch used `demo` before it was ever assigned).
demo = create_demo()

if __name__ == "__main__":
    demo.launch()
else:
    demo.launch(show_api=False)