|
from transformers import Tool, ReactCodeAgent, HfApiEngine |
|
import gradio as gr |
|
import pandas as pd |
|
import numpy as np |
|
import plotly.express as px |
|
import plotly.graph_objects as go |
|
from typing import Dict, List, Optional |
|
import openai |
|
import seaborn as sns |
|
import matplotlib.pyplot as plt |
|
import io |
|
import base64 |
|
|
|
|
|
class DataVisualizationTool(Tool):
    """Agent tool that renders seaborn/matplotlib charts from tabular data
    and returns them as a base64-encoded PNG string."""

    name = "data_visualizer"
    description = """Creates various types of visualizations from data:
    - Correlation heatmaps
    - Distribution plots
    - Scatter plots
    - Time series plots
    Returns the plots as base64 encoded images."""

    inputs = {
        "data": {
            "type": "dict",
            "description": "DataFrame as dictionary"
        },
        "plot_type": {
            "type": "string",
            "description": "Type of plot to create: 'heatmap', 'distribution', 'scatter'"
        },
        "columns": {
            "type": "list",
            "description": "List of columns to plot"
        }
    }
    output_type = "string"

    def forward(self, data: Dict, plot_type: str, columns: List[str]) -> str:
        """Render the requested plot and return it as a base64 PNG string.

        Args:
            data: Column-oriented dict convertible via ``pd.DataFrame``.
            plot_type: One of ``'heatmap'``, ``'distribution'``, ``'scatter'``.
            columns: Columns to plot; scatter uses the first two.

        Returns:
            The PNG image encoded as a base64 ASCII string.

        Raises:
            ValueError: If ``plot_type`` is unknown, or a scatter plot is
                requested with fewer than two columns. (The original code
                silently returned a blank image in both cases.)
        """
        df = pd.DataFrame(data)
        fig = plt.figure(figsize=(10, 6))
        try:
            if plot_type == "heatmap":
                sns.heatmap(df[columns].corr(), annot=True, cmap='coolwarm')
                plt.title("Correlation Heatmap")
            elif plot_type == "distribution":
                for col in columns:
                    sns.histplot(df[col], kde=True, label=col)
                plt.title("Distribution Plot")
                plt.legend()
            elif plot_type == "scatter":
                if len(columns) < 2:
                    raise ValueError("Scatter plot requires at least two columns.")
                sns.scatterplot(data=df, x=columns[0], y=columns[1])
                plt.title(f"Scatter Plot: {columns[0]} vs {columns[1]}")
            else:
                raise ValueError(f"Unknown plot_type: {plot_type!r}")

            buf = io.BytesIO()
            plt.savefig(buf, format='png')
            buf.seek(0)
            return base64.b64encode(buf.read()).decode('utf-8')
        finally:
            # Always release the figure — the original only closed it on the
            # success path, leaking matplotlib state whenever plotting raised
            # (e.g. a missing column).
            plt.close(fig)
|
|
|
class DataAnalysisTool(Tool):
    """Agent tool that runs simple statistical analyses on tabular data."""

    name = "data_analyzer"
    description = """Performs statistical analysis on data:
    - Basic statistics (mean, median, std)
    - Correlation analysis
    - Missing value analysis
    - Outlier detection"""

    inputs = {
        "data": {
            "type": "dict",
            "description": "DataFrame as dictionary"
        },
        "analysis_type": {
            "type": "string",
            "description": "Type of analysis: 'basic', 'correlation', 'missing', 'outliers'"
        },
        "columns": {
            "type": "list",
            "description": "List of columns to analyze"
        }
    }
    output_type = "dict"

    def forward(self, data: Dict, analysis_type: str, columns: List[str]) -> Dict:
        """Run the requested analysis over the selected columns.

        Args:
            data: Column-oriented dict convertible via ``pd.DataFrame``.
            analysis_type: One of ``'basic'``, ``'correlation'``,
                ``'missing'``, ``'outliers'``.
            columns: Columns to analyze; names absent from the data are
                silently ignored.

        Returns:
            A dict of results keyed by metric; ``{"error": ...}`` for an
            unknown ``analysis_type`` (the original fell through and
            returned ``None`` despite declaring ``output_type = "dict"``).
        """
        df = pd.DataFrame(data)
        selected_cols = [col for col in columns if col in df.columns]
        # Several analyses are only defined for numeric data.
        numeric = df[selected_cols].select_dtypes(include=[np.number])

        if analysis_type == "basic":
            return {
                "statistics": df[selected_cols].describe().to_dict(),
                # skew/kurtosis restricted to numeric columns: calling them
                # on object columns raises in modern pandas.
                "skew": numeric.skew().to_dict(),
                "kurtosis": numeric.kurtosis().to_dict()
            }
        elif analysis_type == "correlation":
            return {
                "correlation": numeric.corr().to_dict(),
                "covariance": numeric.cov().to_dict()
            }
        elif analysis_type == "missing":
            return {
                "missing_counts": df[selected_cols].isnull().sum().to_dict(),
                "missing_percentages": (df[selected_cols].isnull().mean() * 100).to_dict()
            }
        elif analysis_type == "outliers":
            outliers = {}
            for col in selected_cols:
                # is_numeric_dtype covers int32/float32 etc., which the
                # original `dtype in [np.float64, np.int64]` check missed.
                if pd.api.types.is_numeric_dtype(df[col]):
                    Q1 = df[col].quantile(0.25)
                    Q3 = df[col].quantile(0.75)
                    IQR = Q3 - Q1
                    lower = Q1 - 1.5 * IQR
                    upper = Q3 + 1.5 * IQR
                    outliers[col] = {
                        "outliers_count": len(df[(df[col] < lower) | (df[col] > upper)]),
                        "lower_bound": lower,
                        "upper_bound": upper
                    }
            return {"outliers": outliers}

        return {"error": f"Unknown analysis_type: {analysis_type!r}"}
|
|
|
def create_demo():
    """Build the Gradio app wiring the analysis/visualization tools to a
    ReAct code agent.

    Returns:
        gr.Blocks: The assembled (not yet launched) Gradio application.
    """
    viz_tool = DataVisualizationTool()
    analysis_tool = DataAnalysisTool()

    # NOTE(review): the agent runs on the Hugging Face Inference API; the
    # "OpenAI API Key" textbox below is collected and threaded through
    # process_message but never used — confirm whether an OpenAI engine
    # was intended here.
    llm_engine = HfApiEngine()
    agent = ReactCodeAgent(
        tools=[viz_tool, analysis_tool],
        llm_engine=llm_engine
    )

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🔬 Advanced Data Analysis Agent")

        with gr.Row():
            with gr.Column():
                api_key = gr.Textbox(
                    label="OpenAI API Key",
                    type="password",
                    placeholder="sk-..."
                )
                file_input = gr.File(
                    label="Upload CSV",
                    file_types=[".csv"]
                )
                with gr.Accordion("Advanced Settings", open=False):
                    system_prompt = gr.Textbox(
                        label="System Prompt",
                        value="""You are a data science expert. Analyze the data and create
                        visualizations to help understand patterns and insights.""",
                        lines=3
                    )

            with gr.Column():
                chat = gr.Chatbot(label="Analysis Chat")
                msg = gr.Textbox(
                    label="Ask about your data",
                    placeholder="What insights can you find in this dataset?"
                )
                clear = gr.Button("Clear")

        # Holds the uploaded DataFrame between events.
        df_state = gr.State(None)

        def process_file(file):
            """Parse the uploaded CSV into a DataFrame (None when cleared)."""
            if file is None:
                return None
            return pd.read_csv(file.name)

        def process_message(message, chat_history, api_key, df):
            """Run the agent on the user's question and append the exchange
            to the chat history. Errors are surfaced in-chat rather than
            crashing the UI."""
            if df is None:
                return chat_history + [(message, "Please upload a CSV file first.")]

            try:
                # (Removed dead `data_dict = df.to_dict()`: it was computed on
                # every message but never passed to the agent or the tools.)
                columns = list(df.columns)

                response = agent.run(
                    f"""Analyze this data: {message}
                    Available columns: {columns}
                    Use the data_analyzer and data_visualizer tools to create insights."""
                )
                return chat_history + [(message, response)]

            except Exception as e:
                return chat_history + [(message, f"Error: {str(e)}")]

        file_input.change(
            process_file,
            inputs=[file_input],
            outputs=[df_state]
        )

        msg.submit(
            process_message,
            inputs=[msg, chat, api_key, df_state],
            outputs=[chat]
        )

        clear.click(lambda: None, None, chat)

    return demo
|
|
|
if __name__ == "__main__":
    # Run as a script: build and serve the UI with default settings.
    demo = create_demo()
    demo.launch()
else:
    # Imported (e.g. by a hosting platform): the original referenced `demo`
    # here without ever defining it, raising NameError on import — build the
    # app before launching.
    demo = create_demo()
    demo.launch(show_api=False)