jzou19950715 committed on
Commit 5293476 · verified · 1 Parent(s): ad9e004

Update app.py

Files changed (1): app.py +117 -192
app.py CHANGED
@@ -1,216 +1,141 @@
- from transformers import Tool, ReactCodeAgent, HfApiEngine
- import gradio as gr
  import pandas as pd
  import numpy as np
  import plotly.express as px
- import plotly.graph_objects as go
- from typing import Dict, List, Optional
- import openai
  import seaborn as sns
- import matplotlib.pyplot as plt
- import io
- import base64

- # Custom Tools for Data Analysis
- class DataVisualizationTool(Tool):
-     name = "data_visualizer"
-     description = """Creates various types of visualizations from data:
-     - Correlation heatmaps
-     - Distribution plots
-     - Scatter plots
-     - Time series plots
-     Returns the plots as base64 encoded images."""

      inputs = {
-         "data": {
-             "type": "dict",
-             "description": "DataFrame as dictionary"
-         },
-         "plot_type": {
-             "type": "string",
-             "description": "Type of plot to create: 'heatmap', 'distribution', 'scatter'"
-         },
-         "columns": {
-             "type": "list",
-             "description": "List of columns to plot"
-         }
      }
-     output_type = "string"  # base64 encoded image

-     def forward(self, data: Dict, plot_type: str, columns: List[str]) -> str:
          df = pd.DataFrame(data)
-         plt.figure(figsize=(10, 6))
-
-         if plot_type == "heatmap":
-             sns.heatmap(df[columns].corr(), annot=True, cmap='coolwarm')
-             plt.title("Correlation Heatmap")
-         elif plot_type == "distribution":
-             for col in columns:
-                 sns.histplot(df[col], kde=True, label=col)
-             plt.title("Distribution Plot")
-             plt.legend()
-         elif plot_type == "scatter":
-             if len(columns) >= 2:
-                 sns.scatterplot(data=df, x=columns[0], y=columns[1])
-                 plt.title(f"Scatter Plot: {columns[0]} vs {columns[1]}")
-
-         # Convert plot to base64
-         buf = io.BytesIO()
-         plt.savefig(buf, format='png')
-         plt.close()
-         buf.seek(0)
-         return base64.b64encode(buf.read()).decode('utf-8')

- class DataAnalysisTool(Tool):
-     name = "data_analyzer"
-     description = """Performs statistical analysis on data:
-     - Basic statistics (mean, median, std)
-     - Correlation analysis
-     - Missing value analysis
-     - Outlier detection"""

      inputs = {
-         "data": {
-             "type": "dict",
-             "description": "DataFrame as dictionary"
-         },
-         "analysis_type": {
-             "type": "string",
-             "description": "Type of analysis: 'basic', 'correlation', 'missing', 'outliers'"
-         },
-         "columns": {
-             "type": "list",
-             "description": "List of columns to analyze"
-         }
      }
      output_type = "dict"

-     def forward(self, data: Dict, analysis_type: str, columns: List[str]) -> Dict:
          df = pd.DataFrame(data)
-         selected_cols = [col for col in columns if col in df.columns]
-
-         if analysis_type == "basic":
-             return {
-                 "statistics": df[selected_cols].describe().to_dict(),
-                 "skew": df[selected_cols].skew().to_dict(),
-                 "kurtosis": df[selected_cols].kurtosis().to_dict()
-             }
-         elif analysis_type == "correlation":
-             numeric_cols = df[selected_cols].select_dtypes(include=[np.number])
-             return {
-                 "correlation": numeric_cols.corr().to_dict(),
-                 "covariance": numeric_cols.cov().to_dict()
-             }
-         elif analysis_type == "missing":
              return {
-                 "missing_counts": df[selected_cols].isnull().sum().to_dict(),
-                 "missing_percentages": (df[selected_cols].isnull().mean() * 100).to_dict()
              }
-         elif analysis_type == "outliers":
-             outliers = {}
-             for col in selected_cols:
-                 if df[col].dtype in [np.float64, np.int64]:
-                     Q1 = df[col].quantile(0.25)
-                     Q3 = df[col].quantile(0.75)
-                     IQR = Q3 - Q1
-                     outliers[col] = {
-                         "outliers_count": len(df[(df[col] < Q1 - 1.5 * IQR) | (df[col] > Q3 + 1.5 * IQR)]),
-                         "lower_bound": Q1 - 1.5 * IQR,
-                         "upper_bound": Q3 + 1.5 * IQR
-                     }
-             return {"outliers": outliers}

- def create_demo():
-     # Initialize tools
-     viz_tool = DataVisualizationTool()
-     analysis_tool = DataAnalysisTool()
-
-     # Create agent with tools
-     llm_engine = HfApiEngine()  # Uses default model
-     agent = ReactCodeAgent(
-         tools=[viz_tool, analysis_tool],
-         llm_engine=llm_engine
-     )

-     with gr.Blocks(theme=gr.themes.Soft()) as demo:
-         gr.Markdown("# 🔬 Advanced Data Analysis Agent")
-
-         with gr.Row():
-             with gr.Column():
-                 api_key = gr.Textbox(
-                     label="OpenAI API Key",
-                     type="password",
-                     placeholder="sk-..."
-                 )
-                 file_input = gr.File(
-                     label="Upload CSV",
-                     file_types=[".csv"]
-                 )
-                 with gr.Accordion("Advanced Settings", open=False):
-                     system_prompt = gr.Textbox(
-                         label="System Prompt",
-                         value="""You are a data science expert. Analyze the data and create
-                         visualizations to help understand patterns and insights.""",
-                         lines=3
-                     )
-
-             with gr.Column():
-                 chat = gr.Chatbot(label="Analysis Chat")
-                 msg = gr.Textbox(
-                     label="Ask about your data",
-                     placeholder="What insights can you find in this dataset?"
-                 )
-                 clear = gr.Button("Clear")
-
-         # State for storing the DataFrame
-         df_state = gr.State(None)
-
-         def process_file(file):
-             if file is None:
-                 return None
-             return pd.read_csv(file.name)
-
-         def process_message(message, chat_history, api_key, df):
-             if df is None:
-                 return chat_history + [(message, "Please upload a CSV file first.")]
-
-             try:
-                 # Convert DataFrame to dict for tools
-                 data_dict = df.to_dict()
-
-                 # Get all columns for potential analysis
-                 columns = list(df.columns)
-
-                 # Use agent to analyze and create visualizations
-                 response = agent.run(
-                     f"""Analyze this data: {message}
-                     Available columns: {columns}
-                     Use the data_analyzer and data_visualizer tools to create insights."""
-                 )
-
-                 return chat_history + [(message, response)]
-
-             except Exception as e:
-                 return chat_history + [(message, f"Error: {str(e)}")]

-         file_input.change(
-             process_file,
-             inputs=[file_input],
-             outputs=[df_state]
-         )

-         msg.submit(
-             process_message,
-             inputs=[msg, chat, api_key, df_state],
-             outputs=[chat]
-         )

-         clear.click(lambda: None, None, chat)

-     return demo
-
- if __name__ == "__main__":
-     demo = create_demo()
-     demo.launch()
- else:
-     demo.launch(show_api=False)

+ from transformers import Tool
  import pandas as pd
  import numpy as np
  import plotly.express as px
  import seaborn as sns
+ from sklearn import preprocessing, decomposition, metrics

+ # 1. Data Loading and Preprocessing Tool
+ class DataPreprocessingTool(Tool):
+     name = "data_preprocessor"
+     description = "Handles data loading, cleaning, and preprocessing tasks"

      inputs = {
+         "data": {"type": "dict", "description": "Input data dictionary"},
+         "operation": {"type": "string", "description": "Operation to perform: clean/encode/normalize"}
      }
+     output_type = "dict"

+     def forward(self, data: dict, operation: str) -> dict:
          df = pd.DataFrame(data)
+         if operation == "clean":
+             # Drop duplicate rows, then mean-impute missing numeric values
+             df = df.drop_duplicates()
+             df = df.fillna(df.mean(numeric_only=True))
+         elif operation == "encode":
+             # Label-encode categorical (object-dtype) columns
+             le = preprocessing.LabelEncoder()
+             for col in df.select_dtypes(include=['object']):
+                 df[col] = le.fit_transform(df[col].astype(str))
+         elif operation == "normalize":
+             # Standardize numeric columns to zero mean, unit variance
+             scaler = preprocessing.StandardScaler()
+             numeric_cols = df.select_dtypes(include=[np.number]).columns
+             df[numeric_cols] = scaler.fit_transform(df[numeric_cols])
+         return df.to_dict()
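
A minimal usage sketch for the preprocessing tool, calling forward() directly for illustration; the sample records and the 'age'/'city' columns are made up, not part of the commit:

# Hypothetical check of DataPreprocessingTool on illustrative data.
tool = DataPreprocessingTool()
raw = {
    "age": {0: 25.0, 1: None, 2: 25.0},
    "city": {0: "Paris", 1: "Tokyo", 2: "Paris"},
}
cleaned = tool.forward(raw, "clean")       # drops the duplicate row, mean-imputes 'age'
encoded = tool.forward(cleaned, "encode")  # label-encodes 'city' to integers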

+ # 2. Statistical Analysis Tool
+ class StatisticalAnalysisTool(Tool):
+     name = "statistical_analyzer"
+     description = "Performs statistical analysis on data"

      inputs = {
+         "data": {"type": "dict", "description": "Input data dictionary"},
+         "analysis_type": {"type": "string", "description": "Type of analysis: descriptive/inferential/correlation"}
      }
      output_type = "dict"

+     def forward(self, data: dict, analysis_type: str) -> dict:
          df = pd.DataFrame(data)
+         if analysis_type == "descriptive":
              return {
+                 "summary": df.describe().to_dict(),
+                 "skewness": df.skew(numeric_only=True).to_dict(),
+                 "kurtosis": df.kurtosis(numeric_only=True).to_dict()
              }
+         elif analysis_type == "inferential":
+             # Run a D'Agostino-Pearson normality test on each numeric column
+             from scipy import stats
+             results = {}
+             numeric_cols = df.select_dtypes(include=[np.number]).columns
+             for col in numeric_cols:
+                 stat, p_value = stats.normaltest(df[col].dropna())
+                 results[col] = {"statistic": stat, "p_value": p_value}
+             return results
+         return df.corr(numeric_only=True).to_dict()
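
A quick sketch of calling the analyzer directly; the single made-up column 'x' is illustrative. Note scipy's normaltest needs at least 8 observations and warns below 20:

# Hypothetical check of StatisticalAnalysisTool.
tool = StatisticalAnalysisTool()
data = {"x": {i: float(i) for i in range(30)}}
summary = tool.forward(data, "descriptive")    # describe() plus skewness/kurtosis
normality = tool.forward(data, "inferential")  # normality statistic and p-value per column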

+ # 3. Advanced Visualization Tool
+ class AdvancedVisualizationTool(Tool):
+     name = "advanced_visualizer"
+     description = "Creates advanced statistical and ML visualizations"

+     inputs = {
+         "data": {"type": "dict", "description": "Input data dictionary"},
+         "viz_type": {"type": "string", "description": "Type of visualization: pca/cluster"},
+         "params": {"type": "dict", "description": "Additional parameters"}
+     }
+     output_type = "dict"

+     def forward(self, data: dict, viz_type: str, params: dict) -> dict:
+         df = pd.DataFrame(data)
+         if viz_type == "pca":
+             # Project the numeric columns onto the first two principal components
+             pca = decomposition.PCA(n_components=2)
+             numeric_cols = df.select_dtypes(include=[np.number]).columns
+             pca_result = pca.fit_transform(df[numeric_cols])
+             fig = px.scatter(x=pca_result[:, 0], y=pca_result[:, 1],
+                              title='PCA Visualization')
+             return {"plot": fig.to_dict()}
+         elif viz_type == "cluster":
+             # Color a scatter plot by k-means cluster assignment
+             from sklearn.cluster import KMeans
+             kmeans = KMeans(n_clusters=params.get("n_clusters", 3))
+             numeric_cols = df.select_dtypes(include=[np.number]).columns
+             clusters = kmeans.fit_predict(df[numeric_cols])
+             fig = px.scatter(df, x=params.get("x"), y=params.get("y"),
+                              color=clusters, title='Cluster Visualization')
+             return {"plot": fig.to_dict()}
+         return {}
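
A sketch of driving the visualizer; the two made-up columns 'a'/'b' are illustrative, and rebuilding the figure from its dict form via plotly.graph_objects is an assumption about how a caller would consume the output:

# Hypothetical check of AdvancedVisualizationTool.
import plotly.graph_objects as go
tool = AdvancedVisualizationTool()
data = {"a": {i: float(i) for i in range(10)},
        "b": {i: float(i % 3) for i in range(10)}}
result = tool.forward(data, "pca", params={})
fig = go.Figure(result["plot"])  # reconstruct the Plotly figure from its dict form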

+ # 4. Machine Learning Tool
+ class MLModelTool(Tool):
+     name = "ml_modeler"
+     description = "Trains and evaluates machine learning models"

+     inputs = {
+         "data": {"type": "dict", "description": "Input data dictionary"},
+         "target": {"type": "string", "description": "Target column name"},
+         "model_type": {"type": "string", "description": "Type of model to train: regression/classification"}
+     }
+     output_type = "dict"

+     def forward(self, data: dict, target: str, model_type: str) -> dict:
+         from sklearn.model_selection import train_test_split
+         from sklearn.metrics import mean_squared_error, accuracy_score

+         df = pd.DataFrame(data)
+         X = df.drop(columns=[target])
+         y = df[target]

+         X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

+         if model_type == "regression":
+             from sklearn.linear_model import LinearRegression
+             model = LinearRegression()
+             model.fit(X_train, y_train)
+             y_pred = model.predict(X_test)
+             return {
+                 "mse": mean_squared_error(y_test, y_pred),
+                 "r2": model.score(X_test, y_test),
+                 "coefficients": dict(zip(X.columns, model.coef_))
+             }
+         elif model_type == "classification":
+             from sklearn.ensemble import RandomForestClassifier
+             model = RandomForestClassifier()
+             model.fit(X_train, y_train)
+             y_pred = model.predict(X_test)
+             return {
+                 "accuracy": accuracy_score(y_test, y_pred),
+                 "feature_importance": dict(zip(X.columns, model.feature_importances_))
+             }
+         return {}
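
Finally, a sketch of the modeling tool on synthetic data; the generated frame and the y = 2*x1 - x2 relationship are illustrative, not from the commit:

# Hypothetical check of MLModelTool on synthetic regression data.
rng = np.random.default_rng(0)
df = pd.DataFrame({"x1": rng.normal(size=100), "x2": rng.normal(size=100)})
df["y"] = 2 * df["x1"] - df["x2"] + rng.normal(scale=0.1, size=100)
tool = MLModelTool()
report = tool.forward(df.to_dict(), target="y", model_type="regression")
# report carries "mse", "r2", and per-feature "coefficients"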