jzou19950715 committed
Commit ad9e004 · verified · 1 Parent(s): cedb0a7

Update app.py

Files changed (1):
  1. app.py +157 -139
app.py CHANGED
@@ -1,162 +1,158 @@
- import os
  import gradio as gr
  import pandas as pd
  import numpy as np
- import matplotlib.pyplot as plt
- import seaborn as sns
  from typing import Dict, List, Optional
  import openai
- from dataclasses import dataclass
- import plotly.express as px
- from sklearn.preprocessing import StandardScaler
- from sklearn.model_selection import train_test_split
- import statsmodels.api as sm
-
- # System prompt for data analysis
- DATA_ANALYSIS_PROMPT = """
- <DataScienceExpertFramework version="1.0">
-     <Identity>
-         <Description>You are an expert data scientist who combines technical precision with clear insights.</Description>
-     </Identity>
-     <CoreCapabilities>
-         <Analysis>
-             <Capability>Statistical analysis and hypothesis testing</Capability>
-             <Capability>Pattern recognition and insights</Capability>
-             <Capability>Data visualization recommendations</Capability>
-         </Analysis>
-     </CoreCapabilities>
-     <AnalysisApproach>
-         <Step>Assess data quality and structure</Step>
-         <Step>Identify key patterns and relationships</Step>
-         <Step>Perform statistical analysis</Step>
-         <Step>Generate visualizations</Step>
-         <Step>Provide actionable insights</Step>
-     </AnalysisApproach>
- </DataScienceExpertFramework>
- """

- def format_stats_results(results: Dict) -> str:
-     """Format statistical results for display"""
-     formatted = []
-     for test_name, result in results.items():
-         if "normality" in test_name:
-             formatted.append(f"- {test_name}: {'Normal' if result['is_normal'] else 'Non-normal'} "
-                              f"(p={result['p_value']:.4f})")
-         elif "correlation" in test_name:
-             formatted.append(f"- {test_name}: {result['correlation']:.4f} "
-                              f"(p={result['p_value']:.4f})")
-     return "\n".join(formatted)

- def analyze_data(df: pd.DataFrame) -> Dict:
-     """Analyze dataframe and return statistics"""
-     analysis = {
-         "shape": df.shape,
-         "dtypes": df.dtypes.to_dict(),
-         "missing": df.isnull().sum().to_dict(),
-         "numeric_summary": df.describe().to_dict(),
-         "correlations": {}
      }
-
-     # Calculate correlations for numeric columns
-     numeric_cols = df.select_dtypes(include=[np.number]).columns
-     if len(numeric_cols) >= 2:
-         corr_matrix = df[numeric_cols].corr()
-         analysis["correlations"] = corr_matrix.to_dict()
-
-     return analysis

- def create_visualizations(df: pd.DataFrame, save_dir: str = "figures") -> List[str]:
-     """Create and save visualizations"""
-     os.makedirs(save_dir, exist_ok=True)
-     paths = []
-
-     # Correlation heatmap
-     numeric_cols = df.select_dtypes(include=[np.number]).columns
-     if len(numeric_cols) >= 2:
-         plt.figure(figsize=(10, 8))
-         sns.heatmap(df[numeric_cols].corr(), annot=True, cmap='coolwarm')
-         plt.title("Correlation Heatmap")
-         path = os.path.join(save_dir, "correlation_heatmap.png")
-         plt.savefig(path)
-         plt.close()
-         paths.append(path)
-
-     # Distribution plots for numeric columns
-     for col in numeric_cols[:5]:  # Limit to first 5 columns
          plt.figure(figsize=(10, 6))
-         sns.histplot(df[col], kde=True)
-         plt.title(f"Distribution of {col}")
-         path = os.path.join(save_dir, f"dist_{col}.png")
-         plt.savefig(path)
          plt.close()
-         paths.append(path)
-
-     return paths

- def chat_with_data_scientist(message: str, history: List, api_key: str, df: Optional[pd.DataFrame] = None) -> List:
-     """Chat with GPT-4o-mini about data analysis"""
-     if not api_key:
-         return history + [
-             ("Please provide an API key to continue.", None)
-         ]
-
-     if df is None:
-         return history + [
-             ("Please upload a CSV file to analyze.", None)
-         ]
-
-     try:
-         client = openai.OpenAI(api_key=api_key)
-
-         # Create analysis summary
-         analysis = analyze_data(df)
-         analysis_text = f"""
-         Dataset Shape: {analysis['shape']}
-         Missing Values: {sum(analysis['missing'].values())}
-         Numeric Columns: {list(analysis['numeric_summary'].keys())}
-         """
-
-         messages = [
-             {"role": "system", "content": DATA_ANALYSIS_PROMPT},
-             {"role": "system", "content": f"Analysis Context:\n{analysis_text}"},
-             {"role": "user", "content": message}
-         ]
-
-         response = client.chat.completions.create(
-             model="gpt-4o-mini",
-             messages=messages,
-             max_tokens=500
-         )
-
-         return history + [
-             (message, response.choices[0].message.content)
-         ]

-     except Exception as e:
-         return history + [
-             (message, f"Error: {str(e)}")
-         ]

  def create_demo():
      with gr.Blocks(theme=gr.themes.Soft()) as demo:
-         gr.Markdown("# 🔬 Data Science Expert")

          with gr.Row():
              with gr.Column():
                  api_key = gr.Textbox(
-                     label="GPT-4o-mini API Key",
-                     placeholder="sk-...",
-                     type="password"
                  )
                  file_input = gr.File(
-                     label="Upload CSV file",
                      file_types=[".csv"]
                  )
-                 system_prompt = gr.Textbox(
-                     label="System Prompt",
-                     value=DATA_ANALYSIS_PROMPT,
-                     lines=5
-                 )

              with gr.Column():
                  chat = gr.Chatbot(label="Analysis Chat")
@@ -166,7 +162,7 @@ def create_demo():
                  )
                  clear = gr.Button("Clear")

-         # Store DataFrame in state
          df_state = gr.State(None)

          def process_file(file):
@@ -174,6 +170,29 @@ def create_demo():
                  return None
              return pd.read_csv(file.name)

          file_input.change(
              process_file,
              inputs=[file_input],
@@ -181,7 +200,7 @@ def create_demo():
          )

          msg.submit(
-             chat_with_data_scientist,
              inputs=[msg, chat, api_key, df_state],
              outputs=[chat]
          )
@@ -190,9 +209,8 @@ def create_demo():
      return demo

- demo = create_demo()
-
  if __name__ == "__main__":
      demo.launch()
  else:
      demo.launch(show_api=False)

+ from transformers import Tool, ReactCodeAgent, HfApiEngine
  import gradio as gr
  import pandas as pd
  import numpy as np
+ import plotly.express as px
+ import plotly.graph_objects as go
  from typing import Dict, List, Optional
  import openai
+ import seaborn as sns
+ import matplotlib.pyplot as plt
+ import io
+ import base64

+ # Custom Tools for Data Analysis
+ class DataVisualizationTool(Tool):
+     name = "data_visualizer"
+     description = """Creates various types of visualizations from data:
+     - Correlation heatmaps
+     - Distribution plots
+     - Scatter plots
+     - Time series plots
+     Returns the plots as base64 encoded images."""

+     inputs = {
+         "data": {
+             "type": "dict",
+             "description": "DataFrame as dictionary"
+         },
+         "plot_type": {
+             "type": "string",
+             "description": "Type of plot to create: 'heatmap', 'distribution', 'scatter'"
+         },
+         "columns": {
+             "type": "list",
+             "description": "List of columns to plot"
+         }
      }
+     output_type = "string"  # base64 encoded image

+     def forward(self, data: Dict, plot_type: str, columns: List[str]) -> str:
+         df = pd.DataFrame(data)
          plt.figure(figsize=(10, 6))
+
+         if plot_type == "heatmap":
+             sns.heatmap(df[columns].corr(), annot=True, cmap='coolwarm')
+             plt.title("Correlation Heatmap")
+         elif plot_type == "distribution":
+             for col in columns:
+                 sns.histplot(df[col], kde=True, label=col)
+             plt.title("Distribution Plot")
+             plt.legend()
+         elif plot_type == "scatter":
+             if len(columns) >= 2:
+                 sns.scatterplot(data=df, x=columns[0], y=columns[1])
+                 plt.title(f"Scatter Plot: {columns[0]} vs {columns[1]}")
+
+         # Convert plot to base64
+         buf = io.BytesIO()
+         plt.savefig(buf, format='png')
          plt.close()
+         buf.seek(0)
+         return base64.b64encode(buf.read()).decode('utf-8')

+ class DataAnalysisTool(Tool):
+     name = "data_analyzer"
+     description = """Performs statistical analysis on data:
+     - Basic statistics (mean, median, std)
+     - Correlation analysis
+     - Missing value analysis
+     - Outlier detection"""
+
+     inputs = {
+         "data": {
+             "type": "dict",
+             "description": "DataFrame as dictionary"
+         },
+         "analysis_type": {
+             "type": "string",
+             "description": "Type of analysis: 'basic', 'correlation', 'missing', 'outliers'"
+         },
+         "columns": {
+             "type": "list",
+             "description": "List of columns to analyze"
+         }
+     }
+     output_type = "dict"
+
+     def forward(self, data: Dict, analysis_type: str, columns: List[str]) -> Dict:
+         df = pd.DataFrame(data)
+         selected_cols = [col for col in columns if col in df.columns]

+         if analysis_type == "basic":
+             return {
+                 "statistics": df[selected_cols].describe().to_dict(),
+                 "skew": df[selected_cols].skew().to_dict(),
+                 "kurtosis": df[selected_cols].kurtosis().to_dict()
+             }
+         elif analysis_type == "correlation":
+             numeric_cols = df[selected_cols].select_dtypes(include=[np.number])
+             return {
+                 "correlation": numeric_cols.corr().to_dict(),
+                 "covariance": numeric_cols.cov().to_dict()
+             }
+         elif analysis_type == "missing":
+             return {
+                 "missing_counts": df[selected_cols].isnull().sum().to_dict(),
+                 "missing_percentages": (df[selected_cols].isnull().mean() * 100).to_dict()
+             }
+         elif analysis_type == "outliers":
+             outliers = {}
+             for col in selected_cols:
+                 if df[col].dtype in [np.float64, np.int64]:
+                     Q1 = df[col].quantile(0.25)
+                     Q3 = df[col].quantile(0.75)
+                     IQR = Q3 - Q1
+                     outliers[col] = {
+                         "outliers_count": len(df[(df[col] < Q1 - 1.5 * IQR) | (df[col] > Q3 + 1.5 * IQR)]),
+                         "lower_bound": Q1 - 1.5 * IQR,
+                         "upper_bound": Q3 + 1.5 * IQR
+                     }
+             return {"outliers": outliers}

  def create_demo():
+     # Initialize tools
+     viz_tool = DataVisualizationTool()
+     analysis_tool = DataAnalysisTool()
+
+     # Create agent with tools
+     llm_engine = HfApiEngine()  # Uses default model
+     agent = ReactCodeAgent(
+         tools=[viz_tool, analysis_tool],
+         llm_engine=llm_engine
+     )
+
      with gr.Blocks(theme=gr.themes.Soft()) as demo:
+         gr.Markdown("# 🔬 Advanced Data Analysis Agent")

          with gr.Row():
              with gr.Column():
                  api_key = gr.Textbox(
+                     label="OpenAI API Key",
+                     type="password",
+                     placeholder="sk-..."
                  )
                  file_input = gr.File(
+                     label="Upload CSV",
                      file_types=[".csv"]
                  )
+                 with gr.Accordion("Advanced Settings", open=False):
+                     system_prompt = gr.Textbox(
+                         label="System Prompt",
+                         value="""You are a data science expert. Analyze the data and create
+                         visualizations to help understand patterns and insights.""",
+                         lines=3
+                     )

              with gr.Column():
                  chat = gr.Chatbot(label="Analysis Chat")

                  )
                  clear = gr.Button("Clear")

+         # State for storing the DataFrame
          df_state = gr.State(None)

          def process_file(file):
                  return None
              return pd.read_csv(file.name)

+         def process_message(message, chat_history, api_key, df):
+             if df is None:
+                 return chat_history + [(message, "Please upload a CSV file first.")]
+
+             try:
+                 # Convert DataFrame to dict for tools
+                 data_dict = df.to_dict()
+
+                 # Get all columns for potential analysis
+                 columns = list(df.columns)
+
+                 # Use agent to analyze and create visualizations
+                 response = agent.run(
+                     f"""Analyze this data: {message}
+                     Available columns: {columns}
+                     Use the data_analyzer and data_visualizer tools to create insights."""
+                 )
+
+                 return chat_history + [(message, response)]
+
+             except Exception as e:
+                 return chat_history + [(message, f"Error: {str(e)}")]
+
          file_input.change(
              process_file,
              inputs=[file_input],

          )

          msg.submit(
+             process_message,
              inputs=[msg, chat, api_key, df_state],
              outputs=[chat]
          )

      return demo

  if __name__ == "__main__":
+     demo = create_demo()
      demo.launch()
  else:
      demo.launch(show_api=False)
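For reference, the outlier branch of the new data_analyzer tool applies the standard 1.5 * IQR rule per column. Below is a minimal standalone sketch of that rule; it is illustrative only and not part of the commit, and the toy DataFrame and the column name "value" are invented for the example.

import pandas as pd

# Toy data; "value" is a made-up column name for illustration.
df = pd.DataFrame({"value": [1.0, 2.0, 2.5, 3.0, 2.8, 100.0]})

col = "value"
Q1 = df[col].quantile(0.25)
Q3 = df[col].quantile(0.75)
IQR = Q3 - Q1
lower_bound = Q1 - 1.5 * IQR
upper_bound = Q3 + 1.5 * IQR

# Same flagging rule as DataAnalysisTool.forward with analysis_type="outliers"
outliers_count = len(df[(df[col] < lower_bound) | (df[col] > upper_bound)])
print({"outliers_count": outliers_count, "lower_bound": lower_bound, "upper_bound": upper_bound})

Note that the data_visualizer tool returns its plot as a base64-encoded PNG string, so a caller would decode it with base64.b64decode before saving it to a file or embedding it in the UI.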