jzou19950715 committed on
Commit 37336a7 · verified · Parent: 6404c67

Update app.py

Files changed (1): app.py (+383, -133)
app.py CHANGED
@@ -1,141 +1,391 @@
- from transformers import Tool
  import pandas as pd
  import numpy as np
  import plotly.express as px
  import seaborn as sns
- from sklearn import preprocessing, decomposition, metrics
-
- # 1. Data Loading and Preprocessing Tool
- class DataPreprocessingTool(Tool):
-     name = "data_preprocessor"
-     description = "Handles data loading, cleaning, and preprocessing tasks"
-
-     inputs = {
-         "data": {"type": "dict", "description": "Input data dictionary"},
-         "operation": {"type": "string", "description": "Operation to perform: clean/encode/normalize/impute"}
-     }
-     output_type = "dict"
-
-     def forward(self, data: dict, operation: str) -> dict:
-         df = pd.DataFrame(data)
-         if operation == "clean":
-             # Handle duplicates, missing values
-             df = df.drop_duplicates()
-             df = df.fillna(df.mean(numeric_only=True))
-         elif operation == "encode":
-             # Encode categorical variables
-             le = preprocessing.LabelEncoder()
-             for col in df.select_dtypes(include=['object']):
-                 df[col] = le.fit_transform(df[col].astype(str))
-         elif operation == "normalize":
-             # Normalize numeric columns
-             scaler = preprocessing.StandardScaler()
-             numeric_cols = df.select_dtypes(include=[np.number]).columns
-             df[numeric_cols] = scaler.fit_transform(df[numeric_cols])
-         return df.to_dict()
-
- # 2. Statistical Analysis Tool
- class StatisticalAnalysisTool(Tool):
-     name = "statistical_analyzer"
-     description = "Performs statistical analysis on data"
-
-     inputs = {
-         "data": {"type": "dict", "description": "Input data dictionary"},
-         "analysis_type": {"type": "string", "description": "Type of analysis: descriptive/inferential/correlation"}
-     }
-     output_type = "dict"
-
-     def forward(self, data: dict, analysis_type: str) -> dict:
-         df = pd.DataFrame(data)
-         if analysis_type == "descriptive":
-             return {
-                 "summary": df.describe().to_dict(),
-                 "skewness": df.skew().to_dict(),
-                 "kurtosis": df.kurtosis().to_dict()
-             }
-         elif analysis_type == "inferential":
-             # Perform statistical tests
-             results = {}
-             numeric_cols = df.select_dtypes(include=[np.number]).columns
-             for col in numeric_cols:
-                 from scipy import stats
-                 stat, p_value = stats.normaltest(df[col].dropna())
-                 results[col] = {"statistic": stat, "p_value": p_value}
-             return results
-         return df.corr().to_dict()
-
- # 3. Advanced Visualization Tool
- class AdvancedVisualizationTool(Tool):
-     name = "advanced_visualizer"
-     description = "Creates advanced statistical and ML visualizations"
-
-     inputs = {
-         "data": {"type": "dict", "description": "Input data dictionary"},
-         "viz_type": {"type": "string", "description": "Type of visualization"},
-         "params": {"type": "dict", "description": "Additional parameters"}
-     }
-     output_type = "dict"
-
-     def forward(self, data: dict, viz_type: str, params: dict) -> dict:
-         df = pd.DataFrame(data)
-         if viz_type == "pca":
-             # PCA visualization
-             pca = decomposition.PCA(n_components=2)
-             numeric_cols = df.select_dtypes(include=[np.number]).columns
-             pca_result = pca.fit_transform(df[numeric_cols])
-             fig = px.scatter(x=pca_result[:, 0], y=pca_result[:, 1],
-                              title='PCA Visualization')
-             return {"plot": fig.to_dict()}
-         elif viz_type == "cluster":
-             # Clustering visualization
-             from sklearn.cluster import KMeans
-             kmeans = KMeans(n_clusters=params.get("n_clusters", 3))
-             numeric_cols = df.select_dtypes(include=[np.number]).columns
-             clusters = kmeans.fit_predict(df[numeric_cols])
-             fig = px.scatter(df, x=params.get("x"), y=params.get("y"),
-                              color=clusters, title='Cluster Visualization')
-             return {"plot": fig.to_dict()}
-         return {}
-
- # 4. Machine Learning Tool
- class MLModelTool(Tool):
-     name = "ml_modeler"
-     description = "Trains and evaluates machine learning models"
-
-     inputs = {
-         "data": {"type": "dict", "description": "Input data dictionary"},
-         "target": {"type": "string", "description": "Target column name"},
-         "model_type": {"type": "string", "description": "Type of model to train"}
-     }
-     output_type = "dict"
-
-     def forward(self, data: dict, target: str, model_type: str) -> dict:
-         from sklearn.model_selection import train_test_split
-         from sklearn.metrics import mean_squared_error, accuracy_score
-
-         df = pd.DataFrame(data)
-         X = df.drop(columns=[target])
-         y = df[target]

-         X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

-         if model_type == "regression":
-             from sklearn.linear_model import LinearRegression
-             model = LinearRegression()
-             model.fit(X_train, y_train)
-             y_pred = model.predict(X_test)
-             return {
-                 "mse": mean_squared_error(y_test, y_pred),
-                 "r2": model.score(X_test, y_test),
-                 "coefficients": dict(zip(X.columns, model.coef_))
              }
-         elif model_type == "classification":
-             from sklearn.ensemble import RandomForestClassifier
-             model = RandomForestClassifier()
-             model.fit(X_train, y_train)
-             y_pred = model.predict(X_test)
-             return {
-                 "accuracy": accuracy_score(y_test, y_pred),
-                 "feature_importance": dict(zip(X.columns, model.feature_importances_))
              }
-         return {}
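
Note: the removed Tool subclasses were self-contained, so they could be smoke-tested without an agent loop; forward() is an ordinary method. A minimal sketch against the pre-commit app.py (the sample frame and its column names are invented for illustration):

    import pandas as pd

    # Hypothetical smoke test; assumes the pre-commit app.py, which
    # defined these classes at module level, is on the import path.
    from app import DataPreprocessingTool, MLModelTool

    sample = pd.DataFrame({
        "x": [1.0, 2.0, None, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
        "y": [2.1, 3.9, 6.2, 8.1, 9.8, 12.2, 14.1, 15.9, 18.2, 19.8],
    }).to_dict()

    cleaned = DataPreprocessingTool().forward(sample, "clean")   # fills the NaN with the column mean
    results = MLModelTool().forward(cleaned, "y", "regression")  # {"mse": ..., "r2": ..., "coefficients": ...}
    print(results["r2"], results["coefficients"])
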
+ import os
+ import gradio as gr
  import pandas as pd
  import numpy as np
  import plotly.express as px
+ import plotly.graph_objects as go
  import seaborn as sns
+ import matplotlib.pyplot as plt
+ from typing import Dict, List, Optional, Tuple, Any
+ from dataclasses import dataclass
+ from transformers import Tool, ReactCodeAgent, HfApiEngine
+ import openai
+ from sklearn.preprocessing import StandardScaler, LabelEncoder
+ from sklearn.model_selection import train_test_split
+ import statsmodels.api as sm
+ import json
+ import base64
+ import io
+
+ # Configuration class for agent settings
+ @dataclass
+ class AgentConfig:
+     """Configuration for the data science agent"""
+     system_prompt: str = """
+     <DataScienceExpertFramework version="2.0">
+         <Identity>
+             <Role>Expert Data Scientist and ML Engineer</Role>
+             <Expertise>
+                 <Area>Statistical Analysis</Area>
+                 <Area>Machine Learning</Area>
+                 <Area>Data Visualization</Area>
+                 <Area>Feature Engineering</Area>
+                 <Area>Time Series Analysis</Area>
+             </Expertise>
+         </Identity>
+         <Capabilities>
+             <DataProcessing>
+                 <Task>Data Cleaning</Task>
+                 <Task>Feature Engineering</Task>
+                 <Task>Preprocessing</Task>
+             </DataProcessing>
+             <Analysis>
+                 <Task>Statistical Testing</Task>
+                 <Task>Pattern Recognition</Task>
+                 <Task>Correlation Analysis</Task>
+             </Analysis>
+             <MachineLearning>
+                 <Task>Model Selection</Task>
+                 <Task>Training</Task>
+                 <Task>Evaluation</Task>
+             </MachineLearning>
+             <Visualization>
+                 <Task>EDA Plots</Task>
+                 <Task>Statistical Plots</Task>
+                 <Task>Model Performance Plots</Task>
+             </Visualization>
+         </Capabilities>
+         <OutputFormat>
+             <Format>Clear Explanations</Format>
+             <Format>Statistical Evidence</Format>
+             <Format>Visual Support</Format>
+             <Format>Actionable Insights</Format>
+         </OutputFormat>
+     </DataScienceExpertFramework>
+     """
+     max_iterations: int = 10
+     temperature: float = 0.7
+     model_name: str = "gpt-4o-mini"
+
+ # Data Analysis State class
+ @dataclass
+ class AnalysisState:
+     """Maintains state for ongoing analysis"""
+     df: Optional[pd.DataFrame] = None
+     current_analysis: Dict = None
+     visualizations: List[Dict] = None
+     model_results: Dict = None
+     error_log: List[str] = None
+
+     def clear(self):
+         self.df = None
+         self.current_analysis = None
+         self.visualizations = None
+         self.model_results = None
+         self.error_log = []
+
+     def log_error(self, error: str):
+         if self.error_log is None:
+             self.error_log = []
+         self.error_log.append(error)
+
+ # Helper functions for data processing
+ def process_uploaded_file(file) -> Tuple[Optional[pd.DataFrame], Dict]:
+     """Process uploaded file and return DataFrame with info"""
+     try:
+         if file.name.endswith('.csv'):
+             df = pd.read_csv(file.name)
+         elif file.name.endswith('.xlsx'):
+             df = pd.read_excel(file.name)
+         elif file.name.endswith('.json'):
+             df = pd.read_json(file.name)
+         else:
+             return None, {"error": "Unsupported file format"}
+
+         info = {
+             "shape": df.shape,
+             "columns": list(df.columns),
+             "dtypes": df.dtypes.to_dict(),
+             "missing_values": df.isnull().sum().to_dict(),
+             "numeric_columns": list(df.select_dtypes(include=[np.number]).columns),
+             "categorical_columns": list(df.select_dtypes(exclude=[np.number]).columns)
+         }

+         return df, info
+     except Exception as e:
+         return None, {"error": str(e)}
+
+ def create_visualization(data: pd.DataFrame, viz_type: str, params: Dict) -> Optional[Dict]:
+     """Create visualization based on type and parameters"""
+     try:
+         if viz_type == "scatter":
+             fig = px.scatter(
+                 data,
+                 x=params["x"],
+                 y=params["y"],
+                 color=params.get("color"),
+                 title=params.get("title", "Scatter Plot")
+             )
+         elif viz_type == "histogram":
+             fig = px.histogram(
+                 data,
+                 x=params["x"],
+                 nbins=params.get("nbins", 30),
+                 title=params.get("title", "Distribution")
+             )
+         elif viz_type == "line":
+             fig = px.line(
+                 data,
+                 x=params["x"],
+                 y=params["y"],
+                 title=params.get("title", "Line Plot")
+             )
+         elif viz_type == "heatmap":
+             numeric_cols = data.select_dtypes(include=[np.number]).columns
+             corr = data[numeric_cols].corr()
+             fig = px.imshow(
+                 corr,
+                 labels=dict(color="Correlation"),
+                 title=params.get("title", "Correlation Heatmap")
+             )
+         else:
+             return None
+
+         return fig.to_dict()
+     except Exception as e:
+         return {"error": str(e)}
+
+ class ChatInterface:
+     """Manages the chat interface and message handling"""
+     def __init__(self, agent_config: AgentConfig):
+         self.config = agent_config
+         self.history = []
+         self.agent = self._create_agent()
+
+     def _create_agent(self) -> ReactCodeAgent:
+         """Initialize the agent with tools"""
+         tools = self._get_tools()
+         llm_engine = HfApiEngine()
+         return ReactCodeAgent(
+             tools=tools,
+             llm_engine=llm_engine,
+             max_iterations=self.config.max_iterations
+         )
+
+     def _get_tools(self) -> List[Tool]:
+         """Get list of available tools"""
+         # Import tools from our tools.py
+         from tools import (
+             DataPreprocessingTool,
+             StatisticalAnalysisTool,
+             VisualizationTool,
+             MLModelTool,
+             TimeSeriesAnalysisTool
+         )

+         return [
+             DataPreprocessingTool(),
+             StatisticalAnalysisTool(),
+             VisualizationTool(),
+             MLModelTool(),
+             TimeSeriesAnalysisTool()
+         ]
+
+     def process_message(self, message: str, analysis_state: AnalysisState) -> Tuple[List, Any]:
+         """Process a message and return updated chat history and results"""
+         try:
+             if analysis_state.df is None:
+                 return self.history + [(message, "Please upload a data file first.")], None
+
+             # Prepare context for the agent
+             context = {
+                 "data_info": {
+                     "shape": analysis_state.df.shape,
+                     "columns": list(analysis_state.df.columns),
+                     "dtypes": analysis_state.df.dtypes.to_dict()
+                 },
+                 "current_analysis": analysis_state.current_analysis,
+                 "available_tools": [tool.name for tool in self._get_tools()]
              }
+
+             # Run agent
+             response = self.agent.run(
+                 f"Context: {json.dumps(context)}\nUser request: {message}"
+             )
+
+             self.history.append((message, response))
+             return self.history, response
+         except Exception as e:
+             error_msg = f"Error processing message: {str(e)}"
+             analysis_state.log_error(error_msg)
+             return self.history + [(message, error_msg)], None
+
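
The new ChatInterface above wires a ReactCodeAgent to tools imported from a separate tools.py. A hedged sketch of driving it headlessly, assuming tools.py exists with the five listed classes and that a Hugging Face token is available to HfApiEngine; note that importing app also executes the module-level else branch at the bottom of the file, which launches the demo:

    import pandas as pd

    from app import AgentConfig, AnalysisState, ChatInterface

    state = AnalysisState()
    state.df = pd.DataFrame({"x": [2, 4, 6, 8], "y": [1, 2, 3, 4]})

    chat = ChatInterface(AgentConfig())
    history, response = chat.process_message("Which columns are correlated?", state)

    # process_message serializes the context with json.dumps, and numpy
    # dtype objects are not JSON-serializable, so the call may land in the
    # except branch; the last history entry then carries the error text.
    print(response if response is not None else history[-1])
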
+ def create_demo():
+     # Initialize configuration and state
+     config = AgentConfig()
+     analysis_state = AnalysisState()
+     chat_interface = ChatInterface(config)
+
+     with gr.Blocks(theme=gr.themes.Soft()) as demo:
+         gr.Markdown("# 🔬 Advanced Data Science Agent")
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 api_key = gr.Textbox(
+                     label="API Key (GPT-4o-mini)",
+                     type="password",
+                     placeholder="sk-..."
+                 )
+
+                 file_input = gr.File(
+                     label="Upload Data",
+                     file_types=[".csv", ".xlsx", ".json"]
+                 )
+
+                 with gr.Accordion("Analysis Settings", open=False):
+                     analysis_type = gr.Radio(
+                         choices=[
+                             "Exploratory Analysis",
+                             "Statistical Analysis",
+                             "Machine Learning",
+                             "Time Series Analysis"
+                         ],
+                         label="Analysis Type",
+                         value="Exploratory Analysis"
+                     )
+
+                     visualization_type = gr.Dropdown(
+                         choices=[
+                             "Automatic",
+                             "Scatter Plots",
+                             "Distributions",
+                             "Correlations",
+                             "Time Series"
+                         ],
+                         label="Visualization Type",
+                         value="Automatic"
+                     )
+
+                     model_params = gr.JSON(
+                         label="Model Parameters",
+                         value={
+                             "test_size": 0.2,
+                             "n_estimators": 100,
+                             "handle_outliers": True
+                         }
+                     )
+
+                 with gr.Accordion("System Settings", open=False):
+                     system_prompt = gr.Textbox(
+                         label="System Prompt",
+                         value=config.system_prompt,
+                         lines=10
+                     )
+
+                     max_iterations = gr.Slider(
+                         minimum=1,
+                         maximum=20,
+                         value=config.max_iterations,
+                         step=1,
+                         label="Max Iterations"
+                     )
+
+             with gr.Column(scale=2):
+                 # Chat interface
+                 chatbot = gr.Chatbot(
+                     label="Analysis Chat",
+                     height=400
+                 )
+
+                 with gr.Row():
+                     text_input = gr.Textbox(
+                         label="Ask about your data",
+                         placeholder="What would you like to analyze?",
+                         lines=2
+                     )
+                     submit_btn = gr.Button("Analyze", variant="primary")
+
+                 with gr.Row():
+                     clear_btn = gr.Button("Clear Chat")
+                     example_btn = gr.Button("Load Example")
+
+                 # Output displays
+                 with gr.Accordion("Visualization", open=True):
+                     plot_output = gr.Plot(label="Generated Plots")
+
+                 with gr.Accordion("Analysis Results", open=True):
+                     results_json = gr.JSON(label="Detailed Results")
+
+                 with gr.Accordion("Error Log", open=False):
+                     error_output = gr.Textbox(label="Errors", lines=3)
+
+         # Event handlers
+         def handle_file_upload(file):
+             df, info = process_uploaded_file(file)
+             if df is not None:
+                 analysis_state.df = df
+                 analysis_state.current_analysis = info
+                 return info, None
+             return {"error": "Failed to load file"}, "Failed to load file"
+
+         def handle_analysis(message, chat_history):
+             history, response = chat_interface.process_message(message, analysis_state)
+             return history
+
+         def handle_clear():
+             analysis_state.clear()
+             chat_interface.history = []
+             return None, None, None, None, None
+
+         def load_example_data():
+             import sklearn.datasets
+             data = sklearn.datasets.load_diabetes()
+             df = pd.DataFrame(data.data, columns=data.feature_names)
+             df['target'] = data.target
+
+             analysis_state.df = df
+             analysis_state.current_analysis = {
+                 "shape": df.shape,
+                 "columns": list(df.columns),
+                 "dtypes": df.dtypes.to_dict()
              }
+
+             return analysis_state.current_analysis, None
+
+         # Connect event handlers
+         file_input.change(
+             handle_file_upload,
+             inputs=[file_input],
+             outputs=[results_json, error_output]
+         )
+
+         submit_btn.click(
+             handle_analysis,
+             inputs=[text_input, chatbot],
+             outputs=[chatbot]
+         )
+
+         text_input.submit(
+             handle_analysis,
+             inputs=[text_input, chatbot],
+             outputs=[chatbot]
+         )
+
+         clear_btn.click(
+             handle_clear,
+             outputs=[chatbot, plot_output, results_json, error_output, file_input]
+         )
+
+         example_btn.click(
+             load_example_data,
+             outputs=[results_json, error_output]
+         )
+
+     return demo
+
+ if __name__ == "__main__":
+     demo = create_demo()
+     demo.launch(share=True)
+ else:
+     demo = create_demo()
+     demo.launch(show_api=False)
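
For reference, the "heatmap" branch of create_visualization() can be reproduced standalone. A minimal sketch with synthetic data (the frame and its column names are invented):

    import numpy as np
    import pandas as pd
    import plotly.express as px

    # Same recipe as create_visualization(df, "heatmap", {}): correlation
    # matrix of the numeric columns rendered with px.imshow.
    rng = np.random.default_rng(0)
    df = pd.DataFrame(rng.normal(size=(100, 3)), columns=["a", "b", "c"])
    df["c"] = 0.8 * df["a"] + 0.2 * df["c"]  # inject one strong correlation

    corr = df.select_dtypes(include=[np.number]).corr()
    fig = px.imshow(corr, labels=dict(color="Correlation"),
                    title="Correlation Heatmap")
    fig.show()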