jzou19950715's picture
Update app.py
ad9e004 verified
raw
history blame
7.47 kB
from transformers import Tool, ReactCodeAgent, HfApiEngine
import gradio as gr
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from typing import Dict, List, Optional
import openai
import seaborn as sns
import matplotlib.pyplot as plt
import io
import base64
# Custom Tools for Data Analysis
class DataVisualizationTool(Tool):
    """Agent tool that draws a chart from tabular data and returns it as a base64 PNG.

    The agent calls :meth:`forward` with the table as a dict (anything
    ``pd.DataFrame`` accepts), a plot type, and the columns to plot.
    """

    name = "data_visualizer"
    description = """Creates various types of visualizations from data:
- Correlation heatmaps
- Distribution plots
- Scatter plots
- Time series plots
Returns the plots as base64 encoded images."""
    # "dict"/"list" are not valid Tool input types; "any" accepts the dict/list
    # values the agent passes. NOTE(review): confirm against the authorized
    # types of the installed transformers version.
    inputs = {
        "data": {
            "type": "any",
            "description": "DataFrame as dictionary",
        },
        "plot_type": {
            "type": "string",
            "description": "Type of plot to create: 'heatmap', 'distribution', 'scatter', 'timeseries'",
        },
        "columns": {
            "type": "any",
            "description": "List of columns to plot",
        },
    }
    output_type = "string"  # base64 encoded image

    def forward(self, data: Dict, plot_type: str, columns: List[str]) -> str:
        """Render the requested plot and return it as a base64-encoded PNG.

        Args:
            data: Table in ``DataFrame.to_dict()`` form (anything ``pd.DataFrame`` accepts).
            plot_type: ``'heatmap'`` (correlation of the numeric columns),
                ``'distribution'`` (histogram + KDE per column), ``'scatter'``
                (first two columns as x, y) or ``'timeseries'`` (each column
                drawn as a line against the DataFrame index).
            columns: Column names to plot; a single name string is also accepted.

        Returns:
            The PNG image, base64-encoded, as a ``str``.

        Raises:
            ValueError: ``plot_type`` is unknown, or ``'scatter'`` got fewer than two columns.
            KeyError: a requested column is not present in ``data``.
        """
        if isinstance(columns, str):
            columns = [columns]
        # Validate up front: the old code silently returned a blank image here.
        if plot_type not in ("heatmap", "distribution", "scatter", "timeseries"):
            raise ValueError(
                f"Unknown plot_type {plot_type!r}; expected one of "
                "'heatmap', 'distribution', 'scatter', 'timeseries'."
            )
        if plot_type == "scatter" and len(columns) < 2:
            raise ValueError("'scatter' needs at least two columns (x, y).")

        df = pd.DataFrame(data)
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            if plot_type == "heatmap":
                # Correlation is only defined for numeric columns; pandas >= 2
                # raises on text columns instead of dropping them.
                numeric = df[columns].select_dtypes(include=[np.number])
                sns.heatmap(numeric.corr(), annot=True, cmap="coolwarm", ax=ax)
                ax.set_title("Correlation Heatmap")
            elif plot_type == "distribution":
                for col in columns:
                    sns.histplot(df[col], kde=True, label=col, ax=ax)
                ax.set_title("Distribution Plot")
                ax.legend()
            elif plot_type == "scatter":
                sns.scatterplot(data=df, x=columns[0], y=columns[1], ax=ax)
                ax.set_title(f"Scatter Plot: {columns[0]} vs {columns[1]}")
            else:  # "timeseries"
                for col in columns:
                    ax.plot(df.index, df[col], label=col)
                ax.set_title("Time Series Plot")
                ax.legend()
            # Convert plot to base64
            buf = io.BytesIO()
            fig.savefig(buf, format="png")
        finally:
            # Always release the figure, even if plotting raised, so repeated
            # agent calls do not accumulate open matplotlib figures.
            plt.close(fig)
        return base64.b64encode(buf.getvalue()).decode("utf-8")
class DataAnalysisTool(Tool):
    """Agent tool that computes summary statistics over selected columns of a table.

    The agent calls :meth:`forward` with the table as a dict (anything
    ``pd.DataFrame`` accepts), an analysis type, and the columns to analyze.
    """

    name = "data_analyzer"
    description = """Performs statistical analysis on data:
- Basic statistics (mean, median, std)
- Correlation analysis
- Missing value analysis
- Outlier detection"""
    # "dict"/"list" are not valid Tool input/output types; "any" accepts the
    # dict/list values exchanged with the agent. NOTE(review): confirm against
    # the authorized types of the installed transformers version.
    inputs = {
        "data": {
            "type": "any",
            "description": "DataFrame as dictionary",
        },
        "analysis_type": {
            "type": "string",
            "description": "Type of analysis: 'basic', 'correlation', 'missing', 'outliers'",
        },
        "columns": {
            "type": "any",
            "description": "List of columns to analyze",
        },
    }
    output_type = "any"

    def forward(self, data: Dict, analysis_type: str, columns: List[str]) -> Dict:
        """Run one analysis over ``columns`` and return the results as nested dicts.

        Args:
            data: Table in ``DataFrame.to_dict()`` form (anything ``pd.DataFrame`` accepts).
            analysis_type: One of:
                ``'basic'``: ``describe()`` plus skew/kurtosis of the numeric columns.
                ``'correlation'``: correlation and covariance of the numeric columns.
                ``'missing'``: null counts and percentages per column.
                ``'outliers'``: 1.5*IQR fences and out-of-fence counts per numeric column.
            columns: Columns to analyze; a single name string is also accepted.
                Names not present in ``data`` are ignored, and duplicates are collapsed.

        Returns:
            A dict keyed by result name (see above).

        Raises:
            ValueError: ``analysis_type`` is unknown, or ``'basic'`` is asked
                for with no matching columns (pandas cannot describe zero columns).
        """
        if isinstance(columns, str):
            columns = [columns]
        df = pd.DataFrame(data)
        # Drop unknown names rather than failing, so a slightly wrong guess from
        # the agent still yields a partial result; dict.fromkeys dedupes in order.
        selected_cols = list(dict.fromkeys(col for col in columns if col in df.columns))
        subset = df[selected_cols]

        if analysis_type == "basic":
            return {
                "statistics": subset.describe().to_dict(),
                # numeric_only: pandas >= 2 raises TypeError on text columns otherwise.
                "skew": subset.skew(numeric_only=True).to_dict(),
                "kurtosis": subset.kurtosis(numeric_only=True).to_dict(),
            }
        if analysis_type == "correlation":
            numeric_cols = subset.select_dtypes(include=[np.number])
            return {
                "correlation": numeric_cols.corr().to_dict(),
                "covariance": numeric_cols.cov().to_dict(),
            }
        if analysis_type == "missing":
            is_missing = subset.isnull()
            return {
                "missing_counts": is_missing.sum().to_dict(),
                "missing_percentages": (is_missing.mean() * 100).to_dict(),
            }
        if analysis_type == "outliers":
            outliers = {}
            for col in selected_cols:
                series = df[col]
                # Accept every numeric dtype (int32, float32, ...), not just
                # int64/float64; bool counts as numeric to pandas but has no
                # meaningful quartiles, so skip it.
                if not pd.api.types.is_numeric_dtype(series) or pd.api.types.is_bool_dtype(series):
                    continue
                q1 = series.quantile(0.25)
                q3 = series.quantile(0.75)
                iqr = q3 - q1
                lower = q1 - 1.5 * iqr
                upper = q3 + 1.5 * iqr
                outliers[col] = {
                    "outliers_count": int(((series < lower) | (series > upper)).sum()),
                    "lower_bound": lower,
                    "upper_bound": upper,
                }
            return {"outliers": outliers}
        # Previously fell through and returned None, which the agent could not act on.
        raise ValueError(
            f"Unknown analysis_type {analysis_type!r}; expected one of "
            "'basic', 'correlation', 'missing', 'outliers'."
        )
def create_demo():
    """Build the Gradio UI that connects a CSV upload and a chat box to a ReactCodeAgent.

    One agent, holding the visualization and analysis tools, is created here
    and reused for every chat turn. On each turn the uploaded DataFrame is
    handed to ``agent.run`` as a ``data`` keyword argument (in
    ``DataFrame.to_dict()`` form). The agent's code can then pass it to the
    ``data_analyzer`` / ``data_visualizer`` tools.

    Returns:
        The ``gr.Blocks`` app. It is not launched; the caller launches it.
    """
    # Initialize tools
    viz_tool = DataVisualizationTool()
    analysis_tool = DataAnalysisTool()
    # Create agent with tools
    llm_engine = HfApiEngine()  # Uses default model
    agent = ReactCodeAgent(
        tools=[viz_tool, analysis_tool],
        llm_engine=llm_engine,
    )

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🔬 Advanced Data Analysis Agent")
        with gr.Row():
            with gr.Column():
                # NOTE(review): this key is collected but never used. The agent
                # runs on HfApiEngine, not OpenAI. Either wire it up or remove it.
                api_key = gr.Textbox(
                    label="OpenAI API Key",
                    type="password",
                    placeholder="sk-...",
                )
                file_input = gr.File(
                    label="Upload CSV",
                    file_types=[".csv"],
                )
                with gr.Accordion("Advanced Settings", open=False):
                    system_prompt = gr.Textbox(
                        label="System Prompt",
                        value="""You are a data science expert. Analyze the data and create
visualizations to help understand patterns and insights.""",
                        lines=3,
                    )
            with gr.Column():
                chat = gr.Chatbot(label="Analysis Chat")
                msg = gr.Textbox(
                    label="Ask about your data",
                    placeholder="What insights can you find in this dataset?",
                )
                clear = gr.Button("Clear")

        # State for storing the DataFrame
        df_state = gr.State(None)

        def process_file(file):
            """Load the uploaded CSV into a DataFrame, or return None when the upload is cleared."""
            if file is None:
                return None
            # Gradio 4 passes a filepath string; Gradio 3 passes a tempfile
            # wrapper exposing `.name`. Accept both.
            return pd.read_csv(getattr(file, "name", file))

        def process_message(message, chat_history, _api_key, instructions, df):
            """Run the agent on one question and return the chat history with the new exchange appended."""
            # The Clear button resets the chatbot value to None.
            history = chat_history or []
            if df is None:
                return history + [(message, "Please upload a CSV file first.")]
            try:
                # Convert DataFrame to dict for tools
                data_dict = df.to_dict()
                # Get all columns for potential analysis
                columns = list(df.columns)
                task = (
                    f"Analyze this data: {message}\n"
                    f"Available columns: {columns}\n"
                    "The uploaded table is available to your code as the variable `data`.\n"
                    "Use the data_analyzer and data_visualizer tools to create insights."
                )
                # Honour the "System Prompt" setting, which was previously ignored.
                if instructions:
                    task = f"{instructions}\n\n{task}"
                # Extra kwargs to ReactCodeAgent.run become variables in the
                # agent's interpreter, which is how `data` reaches the tools.
                # NOTE(review): transformers also echoes these kwargs into the
                # prompt, so very large CSVs may exceed the model context.
                response = agent.run(task, data=data_dict)
                # Chatbot expects text; the agent's final answer may be any object.
                if not isinstance(response, str):
                    response = str(response)
                return history + [(message, response)]
            except Exception as e:
                # UI boundary: surface any agent/tool failure in the chat.
                return history + [(message, f"Error: {str(e)}")]

        file_input.change(
            process_file,
            inputs=[file_input],
            outputs=[df_state],
        )
        msg.submit(
            process_message,
            inputs=[msg, chat, api_key, system_prompt, df_state],
            outputs=[chat],
        )
        clear.click(lambda: None, None, chat)
    return demo
# Build the app before branching: both launch paths below need `demo`.
# Previously the non-__main__ branch used an undefined name and raised NameError.
demo = create_demo()
if __name__ == "__main__":
    demo.launch()
else:
    # When imported rather than run, launch without the auto-generated API page.
    demo.launch(show_api=False)