jzou19950715 committed on
Commit
cedb0a7
·
verified ·
1 Parent(s): dbaa1d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +168 -316
app.py CHANGED
@@ -1,346 +1,198 @@
1
  import os
2
- import requests
3
  import gradio as gr
4
  import pandas as pd
5
  import numpy as np
6
  import matplotlib.pyplot as plt
7
  import seaborn as sns
8
- from typing import Dict, List, Tuple, Optional
 
9
  from dataclasses import dataclass
10
- from sklearn.preprocessing import StandardScaler, LabelEncoder
 
11
  from sklearn.model_selection import train_test_split
12
- from sklearn.metrics import mean_squared_error, r2_score, accuracy_score
13
- from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
14
- from sklearn.impute import SimpleImputer
15
  import statsmodels.api as sm
16
- import plotly.express as px
17
- import plotly.graph_objects as go
18
- from scipy import stats
19
 
20
@dataclass
class AnalysisConfig:
    """Configuration for analysis parameters.

    In the code visible in this file only ``significance_level`` is read
    (by ``DataAnalyzer.perform_statistical_tests``); the remaining fields
    are declared but not consulted here.
    """

    # Upper bound on analysis iterations.
    max_iterations: int = 5
    # Minimum number of samples expected before running an analysis.
    min_samples_for_analysis: int = 30
    # Absolute correlation above which two columns count as strongly related.
    correlation_threshold: float = 0.7
    # Maximum number of categories to show in a single visualization.
    max_categories_for_viz: int = 10
    # Alpha used to turn p-values into normal / significant flags.
    significance_level: float = 0.05
28
-
29
class DataAnalyzer:
    """Data analysis agent: column profiling, plots, statistical tests and a quick model."""

    def __init__(self, api_key: str):
        """Store the OpenAI API key and initialise default configuration and counters."""
        self.api_key = api_key
        self.config = AnalysisConfig()
        self.current_iteration = 0
        self.analysis_results = []

    def call_gpt4o_mini(self, prompt: str, system_prompt: str = "") -> str:
        """Send one prompt to ``gpt-4o-mini`` and return the reply text.

        Args:
            prompt: User message.
            system_prompt: Optional system message; omitted from the request
                when empty so single-argument callers keep working.

        Returns:
            The model reply, or a string starting with ``"API Error: "`` when
            anything fails (this method never raises).
        """
        try:
            # Imported here because the module-level imports do not provide it.
            import openai

            client = openai.OpenAI(api_key=self.api_key)
            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            messages.append({"role": "user", "content": prompt})

            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=messages,
                max_tokens=500,
                temperature=0.7,
            )
            return response.choices[0].message.content
        except Exception as e:
            return f"API Error: {str(e)}"

    def evaluate_code(self, code: str, state: Dict = None) -> Tuple[object, str]:
        """Execute ``code`` and return ``(result, captured_print_output)``.

        ``result`` is whatever the executed code bound to the name ``result``
        (``None`` if it bound nothing). ``print`` calls are captured into
        ``state["print_outputs"]`` instead of being written to stdout.

        SECURITY: this is NOT a sandbox. ``exec`` adds ``__builtins__`` to the
        globals mapping, so the code can import modules and touch the file
        system. Do not feed it untrusted or model-generated code without
        isolation.

        Raises:
            RuntimeError: if the executed code raises.
        """
        if state is None:
            state = {}
        # A caller-supplied state may lack the capture buffer.
        state.setdefault("print_outputs", "")

        def _capture_print(*args):
            state["print_outputs"] += " ".join(map(str, args)) + "\n"

        safe_env = {
            "pd": pd,
            "np": np,
            "plt": plt,
            "sns": sns,
            "stats": stats,
            "print": _capture_print,
        }

        try:
            exec(code, safe_env, state)
            return state.get("result", None), state["print_outputs"]
        except Exception as e:
            raise RuntimeError(f"Code execution failed: {str(e)}") from e

    def analyze_data_types(self, df: pd.DataFrame) -> Dict:
        """Group column names by dtype family and count missing/unique values.

        Numeric columns are selected with ``np.number`` so that every integer
        and float width is included, not only ``int64``/``float64``.
        """
        return {
            "numeric_cols": df.select_dtypes(include=[np.number]).columns.tolist(),
            "categorical_cols": df.select_dtypes(include=["object", "category"]).columns.tolist(),
            "temporal_cols": df.select_dtypes(include=["datetime64"]).columns.tolist(),
            "missing_values": df.isnull().sum().to_dict(),
            "unique_counts": df.nunique().to_dict(),
        }

    def create_visualization(self, df: pd.DataFrame, viz_type: str, columns: List[str]) -> str:
        """Render one plot for ``columns`` and return the path of the saved PNG.

        ``viz_type`` is one of ``"correlation"``, ``"distribution"`` or
        ``"boxplot"``; any other value saves an empty figure. The file name
        contains ``viz_type`` because ``current_iteration`` alone does not
        distinguish the plots produced within one analysis run.
        """
        fig = plt.figure(figsize=(10, 6))
        try:
            if viz_type == "correlation":
                sns.heatmap(df[columns].corr(), annot=True, cmap='coolwarm')
                plt.title("Correlation Matrix")
            elif viz_type == "distribution":
                # All columns are drawn onto the same axes; the title ends up
                # naming the last column only.
                for col in columns:
                    sns.histplot(data=df, x=col, kde=True)
                    plt.title(f"Distribution of {col}")
            elif viz_type == "boxplot":
                sns.boxplot(data=df[columns])
                plt.title("Box Plot of Numeric Variables")

            output_path = f"viz_{viz_type}_{self.current_iteration}.png"
            plt.savefig(output_path)
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close(fig)
        return output_path

    def perform_statistical_tests(self, df: pd.DataFrame, data_types: Dict) -> Dict:
        """Run normality tests on numeric columns and chi-square tests on categorical pairs.

        Args:
            df: Data to test.
            data_types: Mapping with ``"numeric_cols"`` and ``"categorical_cols"``
                lists, as produced by ``analyze_data_types``.

        Returns:
            Dict keyed ``normality_<col>`` (``statistic``, ``p_value``,
            ``is_normal``) and ``chi2_<col1>_<col2>`` (``statistic``,
            ``p_value``, ``is_significant``) with ``col1 < col2``.
        """
        from itertools import combinations

        alpha = self.config.significance_level
        results = {}

        # Normality tests: scipy's normaltest needs at least 8 observations,
        # so shorter columns are skipped.
        for col in data_types["numeric_cols"]:
            values = df[col].dropna()
            if len(values) < 8:
                continue
            stat, p_value = stats.normaltest(values)
            results[f"normality_{col}"] = {
                "statistic": stat,
                "p_value": p_value,
                "is_normal": p_value > alpha,
            }

        # Chi-square tests for each unordered pair of categorical columns.
        for col1, col2 in combinations(sorted(set(data_types["categorical_cols"])), 2):
            contingency = pd.crosstab(df[col1], df[col2])
            if contingency.size == 0:
                # No row has both values present; nothing to test.
                continue
            chi2, p_value, _, _ = stats.chi2_contingency(contingency)
            results[f"chi2_{col1}_{col2}"] = {
                "statistic": chi2,
                "p_value": p_value,
                "is_significant": p_value < alpha,
            }

        return results

    def train_predictive_model(self, df: pd.DataFrame, target_col: str) -> Tuple[float, str]:
        """Fit a random forest for ``target_col`` and return ``(score, metric_name)``.

        A classifier (metric ``"accuracy"``) is used when the target is
        non-numeric or has at most 5 distinct values; otherwise a regressor
        (metric ``"r2"``). The score is computed on a 20% hold-out split.
        """
        # These names are not covered by the module-level imports.
        from sklearn.compose import ColumnTransformer
        from sklearn.pipeline import Pipeline
        from sklearn.preprocessing import OneHotEncoder

        # Rows without a target value cannot be used for training or scoring.
        df = df.dropna(subset=[target_col])
        X = df.drop(columns=[target_col])
        y = df[target_col]

        numeric_transformer = Pipeline([
            ('imputer', SimpleImputer(strategy='median')),
            ('scaler', StandardScaler()),
        ])
        categorical_transformer = Pipeline([
            ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
            ('onehot', OneHotEncoder(handle_unknown='ignore')),
        ])
        preprocessor = ColumnTransformer(
            transformers=[
                ('num', numeric_transformer, X.select_dtypes(include=[np.number]).columns),
                ('cat', categorical_transformer, X.select_dtypes(include=['object', 'category']).columns),
            ])

        is_classification = (not pd.api.types.is_numeric_dtype(y)) or y.nunique() <= 5
        if is_classification:
            model = RandomForestClassifier(n_estimators=100, random_state=42)
            metric = 'accuracy'
        else:
            model = RandomForestRegressor(n_estimators=100, random_state=42)
            metric = 'r2'

        pipeline = Pipeline([
            ('preprocessor', preprocessor),
            ('model', model),
        ])

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        pipeline.fit(X_train, y_train)
        y_pred = pipeline.predict(X_test)

        if metric == 'accuracy':
            score = accuracy_score(y_test, y_pred)
        else:
            score = r2_score(y_test, y_pred)

        return score, metric
182
-
183
class GradioInterface:
    """Gradio interface for the data analysis agent."""

    def __init__(self):
        """Start with no analyzer and no loaded data; both are set per analysis run."""
        self.analyzer = None
        self.df = None

    # Default system prompt shown (and editable) under "Advanced Settings".
    DEFAULT_SYSTEM_PROMPT = """
<DataScienceExpertFramework version="1.0">
<Identity>
<Description>
You are an expert data scientist and analyst who combines technical precision with clear communication. You specialize in uncovering insights through advanced statistical analysis, machine learning, and data visualization.
</Description>
</Identity>
<CoreCapabilities>
<Analysis>
<Capability>Advanced statistical analysis and hypothesis testing</Capability>
<Capability>Machine learning model development and evaluation</Capability>
<Capability>Data visualization and exploratory data analysis</Capability>
<Capability>Pattern recognition and trend identification</Capability>
<Capability>Feature engineering and selection</Capability>
</Analysis>
<Communication>
<Style>Clear and precise technical explanations</Style>
<Style>Business-oriented insights translation</Style>
<Style>Visual representation of complex patterns</Style>
</Communication>
</CoreCapabilities>
<AnalysisApproach>
<Step>Data Quality Assessment</Step>
<Step>Exploratory Data Analysis</Step>
<Step>Statistical Testing</Step>
<Step>Pattern Recognition</Step>
<Step>Insight Generation</Step>
<Step>Visualization Creation</Step>
<Step>Recommendations Development</Step>
</AnalysisApproach>
<OutputGuidelines>
<Format>
<Section>Key Findings Summary</Section>
<Section>Detailed Statistical Analysis</Section>
<Section>Visualization Descriptions</Section>
<Section>Actionable Recommendations</Section>
</Format>
<Standards>
<Standard>Always explain statistical significance</Standard>
<Standard>Provide context for numerical findings</Standard>
<Standard>Highlight practical implications</Standard>
<Standard>Address data limitations</Standard>
</Standards>
</OutputGuidelines>
</DataScienceExpertFramework>
"""

    def create_interface(self):
        """Build and return the ``gr.Blocks`` app (inputs, analyze/clear buttons, outputs)."""
        with gr.Blocks(theme=gr.themes.Soft()) as demo:
            gr.Markdown("# 🔍 Intelligent Data Analysis Agent")

            with gr.Row():
                with gr.Column(scale=1):
                    api_key = gr.Textbox(
                        label="GPT-4o-mini API Key",
                        type="password",
                        placeholder="sk-..."
                    )
                    file_input = gr.File(
                        label="Upload CSV file"
                    )

                    with gr.Accordion("⚙️ Advanced Settings", open=False):
                        system_prompt = gr.TextArea(
                            label="System Prompt",
                            # Class attributes are not visible as bare names
                            # inside a method; they must be reached via self.
                            value=self.DEFAULT_SYSTEM_PROMPT,
                            lines=8
                        )

            with gr.Row():
                analysis_notes = gr.Textbox(
                    label="Analysis Notes (Optional)",
                    placeholder="Any specific analysis preferences...")

            with gr.Row():
                analyze_btn = gr.Button("Analyze Data")
                clear_btn = gr.Button("Clear")

            output_text = gr.Markdown()
            output_gallery = gr.Gallery()

            def analyze(api_key, file, notes, system_prompt):
                """Run the full analysis; return ``(markdown_summary, image_paths)``."""
                if not api_key or not file:
                    return "Please provide both API key and data file.", None

                try:
                    # Accept both a plain path string and an object exposing ``.name``.
                    self.df = pd.read_csv(getattr(file, "name", file))
                    self.analyzer = DataAnalyzer(api_key)

                    # Get AI suggestions for analysis
                    prompt = f"Data columns: {list(self.df.columns)}\nUser notes: {notes}\nSuggest appropriate analyses and visualizations."
                    ai_suggestions = self.analyzer.call_gpt4o_mini(prompt, system_prompt)

                    # Perform analysis
                    data_types = self.analyzer.analyze_data_types(self.df)
                    stats_results = self.analyzer.perform_statistical_tests(self.df, data_types)

                    # Create visualizations (only meaningful with numeric columns)
                    viz_paths = []
                    if data_types["numeric_cols"]:
                        for viz_type in ["correlation", "distribution", "boxplot"]:
                            viz_paths.append(
                                self.analyzer.create_visualization(
                                    self.df, viz_type, data_types["numeric_cols"]
                                )
                            )

                    # Generate summary
                    summary = "\n".join([
                        "",
                        "## Data Analysis Results",
                        "",
                        "### AI Suggestions",
                        f"{ai_suggestions}",
                        "",
                        "### Basic Statistics",
                        f"- Rows: {len(self.df)}",
                        f"- Columns: {len(self.df.columns)}",
                        f"- Missing Values: {sum(data_types['missing_values'].values())}",
                        "",
                        "### Statistical Tests",
                        f"{self._format_stats_results(stats_results)}",
                        "",
                    ])

                    return summary, viz_paths

                except Exception as e:
                    # UI boundary: surface the failure in the output pane.
                    return f"Error during analysis: {str(e)}", None

            analyze_btn.click(
                analyze,
                inputs=[api_key, file_input, analysis_notes, system_prompt],
                outputs=[output_text, output_gallery]
            )

            clear_btn.click(
                lambda: (None, None),
                outputs=[output_text, output_gallery]
            )

        return demo

    @staticmethod
    def _format_stats_results(results: Dict) -> str:
        """Format statistical results as markdown bullets.

        Entries whose key contains ``"normality"`` or ``"chi2"`` produce one
        line each; any other entry is skipped.
        """
        formatted = []
        for test_name, result in results.items():
            if "normality" in test_name:
                verdict = 'Normal' if result['is_normal'] else 'Non-normal'
            elif "chi2" in test_name:
                verdict = 'Significant' if result['is_significant'] else 'Not significant'
            else:
                continue
            formatted.append(f"- {test_name}: {verdict} (p={result['p_value']:.4f})")
        return "\n".join(formatted)
342
 
343
# Script entry point: build the interface and start the Gradio server.
if __name__ == "__main__":
    interface = GradioInterface()
    demo = interface.create_interface()
    # share=True asks Gradio for a public share link in addition to the local server.
    demo.launch(share=True)
 
1
  import os
 
2
  import gradio as gr
3
  import pandas as pd
4
  import numpy as np
5
  import matplotlib.pyplot as plt
6
  import seaborn as sns
7
+ from typing import Dict, List, Optional
8
+ import openai
9
  from dataclasses import dataclass
10
+ import plotly.express as px
11
+ from sklearn.preprocessing import StandardScaler
12
  from sklearn.model_selection import train_test_split
 
 
 
13
  import statsmodels.api as sm
 
 
 
14
 
15
# System prompt for data analysis: sent as the first system message of every
# chat request and shown as the default value of the "System Prompt" textbox.
DATA_ANALYSIS_PROMPT = """
<DataScienceExpertFramework version="1.0">
<Identity>
<Description>You are an expert data scientist who combines technical precision with clear insights.</Description>
</Identity>
<CoreCapabilities>
<Analysis>
<Capability>Statistical analysis and hypothesis testing</Capability>
<Capability>Pattern recognition and insights</Capability>
<Capability>Data visualization recommendations</Capability>
</Analysis>
</CoreCapabilities>
<AnalysisApproach>
<Step>Assess data quality and structure</Step>
<Step>Identify key patterns and relationships</Step>
<Step>Perform statistical analysis</Step>
<Step>Generate visualizations</Step>
<Step>Provide actionable insights</Step>
</AnalysisApproach>
</DataScienceExpertFramework>
"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
def format_stats_results(results: Dict) -> str:
    """Format statistical results for display.

    Args:
        results: Mapping of test name to a result dict. Names containing
            ``"normality"`` need ``is_normal`` and ``p_value``; names
            containing ``"correlation"`` need ``correlation`` and ``p_value``.

    Returns:
        One markdown bullet per recognised entry, joined with newlines.
        Entries of any other kind are skipped; an empty mapping yields ``""``.
    """
    formatted = []
    for test_name, result in results.items():
        if "normality" in test_name:
            verdict = 'Normal' if result['is_normal'] else 'Non-normal'
            formatted.append(f"- {test_name}: {verdict} (p={result['p_value']:.4f})")
        elif "correlation" in test_name:
            formatted.append(
                f"- {test_name}: {result['correlation']:.4f} (p={result['p_value']:.4f})"
            )
    return "\n".join(formatted)
 
 
 
 
 
 
 
 
 
49
 
50
def analyze_data(df: pd.DataFrame) -> Dict:
    """Analyze a dataframe and return summary statistics.

    Returns:
        Dict with keys:
        ``shape`` - ``df.shape``;
        ``dtypes`` - column name to dtype;
        ``missing`` - column name to count of null values;
        ``numeric_summary`` - ``describe()`` of the numeric columns only
        (``{}`` when there are none);
        ``correlations`` - pairwise correlation of the numeric columns
        (``{}`` when fewer than two exist).
    """
    # Restrict to numeric columns up front: describe() on a frame without
    # numeric columns would summarise the text columns instead, and it raises
    # on a frame without any columns.
    numeric_df = df.select_dtypes(include=[np.number])
    has_numeric = numeric_df.shape[1] > 0

    analysis = {
        "shape": df.shape,
        "dtypes": df.dtypes.to_dict(),
        "missing": df.isnull().sum().to_dict(),
        "numeric_summary": numeric_df.describe().to_dict() if has_numeric else {},
        "correlations": {},
    }

    # A correlation matrix needs at least two numeric columns.
    if numeric_df.shape[1] >= 2:
        analysis["correlations"] = numeric_df.corr().to_dict()

    return analysis
67
 
68
def create_visualizations(df: pd.DataFrame, save_dir: str = "figures") -> List[str]:
    """Create and save visualizations for the numeric columns of ``df``.

    Writes a correlation heatmap (when at least two numeric columns exist)
    and a histogram with KDE for each of the first five numeric columns.

    Args:
        df: Data to plot.
        save_dir: Output directory; created if missing.

    Returns:
        Paths of the PNG files written, in creation order.
    """
    import re

    os.makedirs(save_dir, exist_ok=True)
    paths: List[str] = []
    numeric_cols = df.select_dtypes(include=[np.number]).columns

    # Correlation heatmap
    if len(numeric_cols) >= 2:
        fig = plt.figure(figsize=(10, 8))
        try:
            sns.heatmap(df[numeric_cols].corr(), annot=True, cmap='coolwarm')
            plt.title("Correlation Heatmap")
            path = os.path.join(save_dir, "correlation_heatmap.png")
            fig.savefig(path)
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)
        paths.append(path)

    # Distribution plots for numeric columns
    used_names = set()
    for col in numeric_cols[:5]:  # Limit to first 5 columns
        series = df[col].dropna()
        if series.empty:
            # Nothing to plot for a column with no values.
            continue

        # Column names come from the uploaded CSV; keep only characters that
        # are safe in a file name so a name cannot contain path separators.
        safe = re.sub(r"[^A-Za-z0-9._-]+", "_", str(col)).strip("._") or "column"
        name, suffix = safe, 1
        while name in used_names:
            # Two different columns may sanitise to the same file name.
            suffix += 1
            name = f"{safe}_{suffix}"
        used_names.add(name)

        fig = plt.figure(figsize=(10, 6))
        try:
            sns.histplot(series, kde=True)
            plt.title(f"Distribution of {col}")
            path = os.path.join(save_dir, f"dist_{name}.png")
            fig.savefig(path)
        finally:
            plt.close(fig)
        paths.append(path)

    return paths
95
 
96
def chat_with_data_scientist(message: str, history: List, api_key: str, df: Optional[pd.DataFrame] = None) -> List:
    """Chat with GPT-4o-mini about the uploaded data.

    Args:
        message: The user's new question.
        history: Existing chat turns as ``(user_text, bot_text)`` pairs;
            ``None`` is treated as an empty history.
        api_key: OpenAI API key.
        df: The uploaded data, or ``None`` if nothing has been uploaded.

    Returns:
        A new history list with one ``(message, reply)`` pair appended. The
        reply is the model's answer, a notice when the key or data is missing,
        or ``"Error: ..."`` when the request fails. A blank ``message``
        returns the history unchanged.
    """
    history = list(history) if history else []

    if not message or not message.strip():
        return history

    # Notices go in the bot slot, paired with what the user actually typed.
    if not api_key:
        return history + [(message, "Please provide an API key to continue.")]

    if df is None:
        return history + [(message, "Please upload a CSV file to analyze.")]

    try:
        client = openai.OpenAI(api_key=api_key)

        # Create analysis summary
        analysis = analyze_data(df)
        # Select numeric columns directly so the label below is accurate even
        # when the frame has no numeric columns.
        numeric_columns = df.select_dtypes(include=[np.number]).columns.tolist()
        analysis_text = (
            "\n"
            f"Dataset Shape: {analysis['shape']}\n"
            f"Missing Values: {sum(analysis['missing'].values())}\n"
            f"Numeric Columns: {numeric_columns}\n"
        )

        messages = [
            {"role": "system", "content": DATA_ANALYSIS_PROMPT},
            {"role": "system", "content": f"Analysis Context:\n{analysis_text}"},
        ]
        # Replay earlier turns so follow-up questions have their context.
        for turn in history:
            if not isinstance(turn, (list, tuple)) or len(turn) != 2:
                continue
            user_text, bot_text = turn
            if user_text:
                messages.append({"role": "user", "content": str(user_text)})
            if bot_text:
                messages.append({"role": "assistant", "content": str(bot_text)})
        messages.append({"role": "user", "content": message})

        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            max_tokens=500
        )

        return history + [(message, response.choices[0].message.content)]

    except Exception as e:
        # UI boundary: show the failure in the chat instead of crashing the handler.
        return history + [(message, f"Error: {str(e)}")]
139
 
140
def create_demo():
    """Build and return the Gradio Blocks app.

    Layout: a left column with the API key, CSV upload and system-prompt
    textbox; a right column with the chat window, question box and a Clear
    button. The uploaded CSV is parsed once and kept in a ``gr.State``.
    """

    def process_file(file):
        """Parse the uploaded CSV into a DataFrame; ``None`` clears the state."""
        if file is None:
            return None
        # Accept both a plain path string and an object exposing ``.name``.
        path = getattr(file, "name", file)
        try:
            return pd.read_csv(path)
        except (OSError, ValueError) as exc:
            # Show unreadable or malformed files as a UI error message.
            raise gr.Error(f"Could not read CSV file: {exc}") from exc

    def respond(message, history, key, df):
        """Answer the question and clear the input box."""
        return "", chat_with_data_scientist(message, history, key, df)

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🔬 Data Science Expert")

        with gr.Row():
            with gr.Column():
                api_key = gr.Textbox(
                    label="GPT-4o-mini API Key",
                    placeholder="sk-...",
                    type="password"
                )
                file_input = gr.File(
                    label="Upload CSV file",
                    file_types=[".csv"]
                )
                # NOTE(review): this textbox is displayed but not wired into
                # any event handler, so edits to it do not affect the chat.
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    value=DATA_ANALYSIS_PROMPT,
                    lines=5
                )

            with gr.Column():
                chat = gr.Chatbot(label="Analysis Chat")
                msg = gr.Textbox(
                    label="Ask about your data",
                    placeholder="What insights can you find in this dataset?"
                )
                clear = gr.Button("Clear")

        # Store DataFrame in state
        df_state = gr.State(None)

        file_input.change(
            process_file,
            inputs=[file_input],
            outputs=[df_state]
        )

        msg.submit(
            respond,
            inputs=[msg, chat, api_key, df_state],
            outputs=[msg, chat]
        )

        clear.click(lambda: None, None, chat)

    return demo
192
+
193
# Build the app at import time so a module-level ``demo`` object exists.
demo = create_demo()

if __name__ == "__main__":
    demo.launch()
else:
    # NOTE(review): this also starts the server when the module is merely
    # imported; kept as-is because hosting may rely on it. Confirm before
    # removing.
    demo.launch(show_api=False)