jzou19950715 committed on
Commit
e279441
·
verified ·
1 Parent(s): 4722ac6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +251 -239
app.py CHANGED
@@ -2,249 +2,261 @@ import os
2
  import requests
3
  import gradio as gr
4
  import pandas as pd
 
5
  import matplotlib.pyplot as plt
6
  import seaborn as sns
7
- import numpy as np
8
-
 
9
  from sklearn.model_selection import train_test_split
10
- from sklearn.linear_model import LogisticRegression
11
- from sklearn.preprocessing import LabelEncoder
12
-
13
- ##############################################################################
14
- # GPT-4o-mini Placeholder - Adjust for your real endpoint & JSON
15
- ##############################################################################
16
def call_gpt4o_mini(api_key, user_prompt):
    """
    Send ``user_prompt`` to the placeholder GPT-4o-mini endpoint and return the reply.

    The URL and JSON layout are placeholders; adjust both to the real service.

    Args:
        api_key: Bearer token. Must be a string starting with ``"sk-"``.
        user_prompt: Text sent as the ``"prompt"`` field of the request body.

    Returns:
        The text found at ``data["choices"][0]["text"]`` on success. On an
        invalid token, a transport/HTTP failure, or an unexpected response
        shape, a human-readable message string is returned instead of raising.
    """
    # isinstance guard: a non-string key (e.g. None) must not reach .startswith().
    if not isinstance(api_key, str) or not api_key.startswith("sk-"):
        return "Please provide a valid GPT-4o-mini token (sk-...)."

    url = "https://api.gpt4o-mini.com/v1/chat"  # <--- Replace with real endpoint
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "prompt": user_prompt,
        "max_tokens": 128,  # limit tokens for cost
        "temperature": 0.7,
    }
    try:
        response = requests.post(url, json=payload, headers=headers, timeout=10)
        response.raise_for_status()
        data = response.json()
        # Suppose the text is in data["choices"][0]["text"] (adjust if needed)
        return data["choices"][0]["text"]
    except (requests.RequestException, ValueError, KeyError, IndexError, TypeError) as e:
        # Network/HTTP errors, invalid JSON, or a response without the expected
        # keys. Anything else is a programming error and is left to propagate.
        return f"Error calling GPT-4o-mini: {str(e)}"
43
-
44
-
45
- ##############################################################################
46
- # Local Data Analysis
47
- ##############################################################################
48
def extended_analysis(df):
    """
    Produce the extended analysis artefacts for ``df``.

    Steps (each is skipped when its precondition is not met):
      1. Correlation heatmap over the numeric columns (needs more than one).
      2. Bar plot of the ``'Career'`` value counts (needs a ``'Career'`` column).
      3. Logistic regression predicting ``'Career'`` from the numeric columns
         (needs ``'Career'``, at least one numeric column and more than one class).

    ``df`` is not modified: the encoded target lives in a local array, so
    repeated calls on the same DataFrame see the same numeric feature set.

    Args:
        df: pandas DataFrame to analyse.

    Returns:
        ``(list_of_image_paths, info_string)`` - the PNG files written to the
        working directory and a one-line note about the classification step.
    """
    output_paths = []
    numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
    has_career = "Career" in df.columns

    # 1) Correlation Heatmap
    if len(numeric_cols) > 1:
        corr = df[numeric_cols].corr()
        plt.figure(figsize=(8, 6))
        try:
            sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f")
            plt.title("Correlation Heatmap")
            heatmap_path = "heatmap.png"
            plt.savefig(heatmap_path)
        finally:
            plt.close()  # release the figure even if plotting/saving fails
        output_paths.append(heatmap_path)

    # 2) Bar Plot for 'Career'
    if has_career:
        plt.figure(figsize=(8, 5))
        try:
            career_counts = df["Career"].value_counts()
            sns.barplot(x=career_counts.index, y=career_counts.values)
            plt.title("Distribution of Careers")
            plt.xlabel("Career")
            plt.ylabel("Count")
            plt.xticks(rotation=45, ha="right")
            barplot_path = "career_distribution.png"
            plt.savefig(barplot_path)
        finally:
            plt.close()
        output_paths.append(barplot_path)

    # 3) Simple Logistic Regression
    if not (has_career and numeric_cols):
        return output_paths, "No 'Career' column or insufficient numeric columns for classification."

    # Rows without a label cannot be used for training; drop them locally.
    labelled = df[df["Career"].notna()]
    y = LabelEncoder().fit_transform(labelled["Career"])
    if len(np.unique(y)) <= 1:
        return output_paths, "Only one category in 'Career'; no classification performed."

    X = labelled[numeric_cols].fillna(0)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    model = LogisticRegression(max_iter=1000)
    model.fit(X_train, y_train)
    score = model.score(X_test, y_test)
    accuracy_info = f"Logistic Regression accuracy: {score:.2f}"

    return output_paths, accuracy_info
101
-
102
-
103
- ##############################################################################
104
- # Main Chat/Analysis Function
105
- ##############################################################################
106
def handle_chat(user_message, df, chat_history, api_key):
    """
    Handle one chat turn and return the updated message list.

    - The user's message is always recorded first, including when no CSV is loaded.
    - If ``df`` is None, the assistant asks for a CSV and nothing else happens.
    - Otherwise a local data summary is built and, when ``api_key`` is truthy,
      GPT-4o-mini is asked for suggestions via ``call_gpt4o_mini``.
    - If the message contains one of the trigger phrases, ``extended_analysis``
      runs and each produced image is appended as its own assistant message.

    Args:
        user_message: Text typed by the user.
        df: Loaded pandas DataFrame, or None when nothing has been uploaded.
        chat_history: Previous messages as ``{"role": ..., "content": ...}``
            dicts (None is treated as empty). The input list is not mutated.
        api_key: Optional GPT-4o-mini token.

    Returns:
        A new list of messages in the Gradio Chatbot ``type='messages'`` format.
    """
    history = list(chat_history) if chat_history else []

    # Always show user message in chat
    history.append({"role": "user", "content": user_message})

    if df is None:
        history.append({"role": "assistant", "content": "Please upload a CSV first."})
        return history

    # Summarize data
    numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
    cat_cols = df.select_dtypes(exclude=["number"]).columns.tolist()
    summary = (
        f"Rows: {df.shape[0]}, Columns: {df.shape[1]}\n"
        f"Numeric: {', '.join(numeric_cols) if numeric_cols else 'None'}\n"
        f"Categorical: {', '.join(cat_cols) if cat_cols else 'None'}"
    )

    # Possibly call GPT-4o-mini for suggestions
    gpt_reply = ""
    if api_key:
        prompt = f"Data Summary:\n{summary}\nUser Query: {user_message}"
        gpt_reply = call_gpt4o_mini(api_key, prompt)

    # Build the reply text (local summary + LLM suggestions)
    reply_text = f"**Data Summary**:\n{summary}"
    if gpt_reply:
        reply_text += f"\n\n**GPT-4o-mini**: {gpt_reply}"

    # Check if user wants extended analysis
    triggers = ["sample analysis", "extended analysis", "advanced analysis", "run analysis", "visualize", "plot"]
    image_paths = []
    if any(t in user_message.lower() for t in triggers):
        image_paths, info = extended_analysis(df)
        if info:
            reply_text += f"\n\n**Analysis Info**: {info}"

    history.append({"role": "assistant", "content": reply_text})
    # In the messages format a file is sent as a dict with a "path" key in
    # `content`; a message with content=None and a separate "image" key is not
    # a documented message shape.
    for path in image_paths:
        history.append({"role": "assistant", "content": {"path": path}})
    return history
157
-
158
-
159
- ##############################################################################
160
- # Gradio Interface
161
- ##############################################################################
162
def create_demo():
    """
    Assemble the Gradio Blocks chat UI and return it without launching it.

    Two pieces of state are kept: the uploaded DataFrame and the list of chat
    messages. Pressing Enter in the message box and clicking "Send" run the
    same three-step chain: update the chat state, mirror it into the Chatbot,
    then empty the message box.
    """
    with gr.Blocks() as demo:
        # State: holds the DataFrame and the chat messages
        df_state = gr.State(None)
        chat_state = gr.State([])  # store messages as list of dicts: [{"role": "...", "content": "..."}]

        gr.Markdown("## GPT-4o-mini Data Analysis Assistant (Chat)")
        gr.Markdown(
            """
            1. Enter your GPT-4o-mini token (`sk-...`) if you want AI suggestions.
            2. Upload a CSV file.
            3. Ask questions or request "sample analysis", "visualize", etc.
            4. Images are displayed in the chat when relevant.
            """
        )

        api_key_box = gr.Textbox(label="GPT-4o-mini Token (sk-...)", placeholder="Optional: sk-xxxx")
        file_input = gr.File(label="Upload CSV", file_types=[".csv"])

        # Chatbot in "messages" format to fix the deprecation warning
        chatbot = gr.Chatbot(label="Chat Output", type="messages")

        user_message = gr.Textbox(label="Your Message", placeholder="Ask about your data...")

        def upload_csv(file):
            """Load the uploaded CSV into a DataFrame; None when the upload is cleared."""
            return None if file is None else pd.read_csv(file.name)

        def on_user_message(message, df, chat_history, api_key):
            """Run one chat turn; whitespace-only messages leave the history untouched."""
            if message.strip():
                return handle_chat(message, df, chat_history, api_key)
            return chat_history  # ignore empty

        def clear_chat():
            """Reset both the stored messages and the visible chat."""
            return [], []

        file_input.change(fn=upload_csv, inputs=file_input, outputs=df_state)

        # Button to send message
        send_btn = gr.Button("Send")

        # Enter in the textbox and the Send button share one event chain.
        for trigger in (user_message.submit, send_btn.click):
            trigger(
                fn=on_user_message,
                inputs=[user_message, df_state, chat_state, api_key_box],
                outputs=chat_state
            ).then(
                # After updating chat_state, reflect it in the chatbot
                fn=lambda messages: messages,
                inputs=chat_state,
                outputs=chatbot
            ).then(
                fn=lambda: "",
                outputs=user_message
            )

        # Clear chat button
        clear_btn = gr.Button("Clear Chat")
        clear_btn.click(
            fn=clear_chat,
            inputs=[],
            outputs=[chat_state, chatbot]
        )

    return demo
246
-
247
# Build the Blocks app at import time; `demo` is a module-level name.
demo = create_demo()

if __name__ == "__main__":
    # Only start the server when this file is run as a script.
    demo.launch(share=True)
 
 
 
2
  import requests
3
  import gradio as gr
4
  import pandas as pd
5
+ import numpy as np
6
  import matplotlib.pyplot as plt
7
  import seaborn as sns
8
+ from typing import Dict, List, Tuple, Optional
9
+ from dataclasses import dataclass
10
+ from sklearn.preprocessing import StandardScaler, LabelEncoder
11
  from sklearn.model_selection import train_test_split
12
+ from sklearn.metrics import mean_squared_error, r2_score, accuracy_score
13
+ from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
14
+ from sklearn.impute import SimpleImputer
15
+ import statsmodels.api as sm
16
+ import plotly.express as px
17
+ import plotly.graph_objects as go
18
+ from scipy import stats
19
+
20
@dataclass
class AnalysisConfig:
    """Tunable thresholds for the analysis agent; every field has a default."""

    # Iteration cap (default 5).
    max_iterations: int = 5
    # Minimum sample count for analysis (default 30).
    min_samples_for_analysis: int = 30
    # Correlation threshold (default 0.7).
    correlation_threshold: float = 0.7
    # Category cap for visualizations (default 10).
    max_categories_for_viz: int = 10
    # p-value cutoff used when labelling statistical test results (default 0.05).
    significance_level: float = 0.05
28
+
29
class DataAnalyzer:
    """Data analysis agent: type profiling, plots, statistical tests and a baseline model."""

    # The skewness test used inside scipy.stats.normaltest needs at least
    # 8 observations.
    _MIN_NORMALTEST_SAMPLES = 8

    def __init__(self, api_key: str):
        """Store the API key and start with default config and empty results."""
        self.api_key = api_key
        self.config = AnalysisConfig()
        self.current_iteration = 0  # used to number the saved plot files
        self.analysis_results = []

    def call_gpt4o_mini(self, prompt: str) -> str:
        """Send ``prompt`` to the GPT-4o-mini endpoint.

        Returns the text at ``choices[0].text`` of the JSON response, or a
        string starting with ``"API Error: "`` when the request fails or the
        response does not have the expected shape.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        try:
            response = requests.post(
                "https://api.gpt4o-mini.example.com/v1/chat",  # Replace with actual endpoint
                json={"prompt": prompt, "max_tokens": 500, "temperature": 0.7},
                headers=headers,
                timeout=15
            )
            response.raise_for_status()
            return response.json()["choices"][0]["text"]
        except (requests.RequestException, ValueError, KeyError, IndexError, TypeError) as e:
            # Transport/HTTP failure, invalid JSON, or unexpected response keys.
            return f"API Error: {str(e)}"

    def analyze_data_types(self, df: pd.DataFrame) -> Dict:
        """Group column names by dtype and count missing/unique values per column.

        Returns a dict with keys ``numeric_cols``, ``categorical_cols``,
        ``temporal_cols`` (lists of column names), ``missing_values`` and
        ``unique_counts`` (column name -> count).
        """
        analysis = {
            "numeric_cols": df.select_dtypes(include=['int64', 'float64']).columns.tolist(),
            "categorical_cols": df.select_dtypes(include=['object', 'category']).columns.tolist(),
            "temporal_cols": df.select_dtypes(include=['datetime64']).columns.tolist(),
            "missing_values": df.isnull().sum().to_dict(),
            "unique_counts": df.nunique().to_dict()
        }
        return analysis

    def create_visualization(self, df: pd.DataFrame, viz_type: str, columns: List[str]) -> str:
        """Draw one plot of ``columns`` and save it as ``viz_<n>.png``.

        ``viz_type`` is one of ``"correlation"`` (heatmap), ``"distribution"``
        (one histogram panel per column) or ``"boxplot"``; any other value
        saves an empty figure. ``self.current_iteration`` supplies ``<n>`` and
        is advanced after each save so that successive calls do not overwrite
        each other's files.

        Returns the path of the saved PNG.
        """
        if viz_type == "distribution":
            # One panel per column: differently scaled variables stay readable
            # and each panel keeps its own title.
            n_panel_cols = max(1, min(3, len(columns)))
            n_panel_rows = max(1, -(-len(columns) // n_panel_cols))  # ceil division
            fig, axes = plt.subplots(
                n_panel_rows,
                n_panel_cols,
                figsize=(5 * n_panel_cols, 4 * n_panel_rows),
                squeeze=False,
            )
        else:
            fig = plt.figure(figsize=(10, 6))
            axes = None

        try:
            if viz_type == "correlation":
                sns.heatmap(df[columns].corr(), annot=True, cmap='coolwarm')
                plt.title("Correlation Matrix")
            elif viz_type == "distribution":
                panels = axes.ravel()
                for ax, col in zip(panels, columns):
                    sns.histplot(data=df, x=col, kde=True, ax=ax)
                    ax.set_title(f"Distribution of {col}")
                for ax in panels[len(columns):]:
                    ax.set_visible(False)  # hide unused grid cells
                fig.tight_layout()
            elif viz_type == "boxplot":
                sns.boxplot(data=df[columns])
                plt.title("Box Plot of Numeric Variables")

            output_path = f"viz_{self.current_iteration}.png"
            fig.savefig(output_path)
        finally:
            plt.close(fig)  # release the figure even when plotting fails

        self.current_iteration += 1
        return output_path

    def perform_statistical_tests(self, df: pd.DataFrame, data_types: Dict) -> Dict:
        """Run normality tests on numeric columns and chi-square tests on categorical pairs.

        ``data_types`` must provide ``numeric_cols`` and ``categorical_cols``.
        Numeric columns with fewer than 8 non-missing values are skipped, as
        are categorical pairs whose contingency table is empty.

        Returns a dict keyed ``normality_<col>`` / ``chi2_<col1>_<col2>``; each
        value holds ``statistic``, ``p_value`` and ``is_normal`` or
        ``is_significant`` as plain Python float/bool.
        """
        results = {}
        alpha = self.config.significance_level

        # Normality tests for numeric columns
        for col in data_types["numeric_cols"]:
            values = df[col].dropna()
            if len(values) >= self._MIN_NORMALTEST_SAMPLES:
                stat, p_value = stats.normaltest(values)
                results[f"normality_{col}"] = {
                    "statistic": float(stat),
                    "p_value": float(p_value),
                    "is_normal": bool(p_value > alpha)
                }

        # Chi-square tests for categorical columns
        for col1 in data_types["categorical_cols"]:
            for col2 in data_types["categorical_cols"]:
                if col1 < col2:  # each unordered pair once
                    contingency = pd.crosstab(df[col1], df[col2])
                    if contingency.size == 0:
                        continue  # no row has both values present
                    chi2, p_value, _, _ = stats.chi2_contingency(contingency)
                    results[f"chi2_{col1}_{col2}"] = {
                        "statistic": float(chi2),
                        "p_value": float(p_value),
                        "is_significant": bool(p_value < alpha)
                    }

        return results

    def train_predictive_model(self, df: pd.DataFrame, target_col: str) -> Tuple[float, str]:
        """Fit a random-forest baseline for ``target_col`` and score it on a 20% hold-out.

        A classifier is used when the target is non-numeric or has at most 5
        distinct values; otherwise a regressor. Rows with a missing target are
        dropped before splitting.

        Returns ``(score, metric)`` where ``metric`` is ``'accuracy'`` or ``'r2'``.
        """
        # Imported here because the module-level imports do not provide these names.
        from sklearn.compose import ColumnTransformer
        from sklearn.pipeline import Pipeline
        from sklearn.preprocessing import OneHotEncoder

        usable = df.dropna(subset=[target_col])
        X = usable.drop(columns=[target_col])
        y = usable[target_col]

        # Preprocessing
        numeric_transformer = Pipeline([
            ('imputer', SimpleImputer(strategy='median')),
            ('scaler', StandardScaler())
        ])

        categorical_transformer = Pipeline([
            ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
            ('onehot', OneHotEncoder(handle_unknown='ignore'))
        ])

        preprocessor = ColumnTransformer(
            transformers=[
                ('num', numeric_transformer, X.select_dtypes(include=['int64', 'float64']).columns),
                ('cat', categorical_transformer, X.select_dtypes(include=['object']).columns)
            ])

        # A regressor cannot be fitted on labels, so non-numeric targets are
        # always treated as classification.
        is_classification = (
            not pd.api.types.is_numeric_dtype(y) or y.nunique() <= 5
        )
        if is_classification:
            model = RandomForestClassifier(n_estimators=100, random_state=42)
            metric = 'accuracy'
        else:
            model = RandomForestRegressor(n_estimators=100, random_state=42)
            metric = 'r2'

        pipeline = Pipeline([
            ('preprocessor', preprocessor),
            ('model', model)
        ])

        # Train and evaluate
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        pipeline.fit(X_train, y_train)
        y_pred = pipeline.predict(X_test)

        if metric == 'accuracy':
            score = accuracy_score(y_test, y_pred)
        else:
            score = r2_score(y_test, y_pred)

        return score, metric
160
+
161
class GradioInterface:
    """Gradio front end for the data analysis agent."""

    def __init__(self):
        # Most recent analyzer / DataFrame, kept for inspection. The analysis
        # itself works on local variables (see `analyze` below).
        self.analyzer = None
        self.df = None

    def create_interface(self):
        """Build the Blocks UI (inputs, Analyze/Clear buttons, outputs) and return it."""
        with gr.Blocks() as demo:
            gr.Markdown("# Intelligent Data Analysis Agent")

            with gr.Row():
                api_key = gr.Textbox(label="GPT-4o-mini API Key", type="password")
                file_input = gr.File(label="Upload CSV file")

            with gr.Row():
                analysis_notes = gr.Textbox(label="Analysis Notes (Optional)",
                                            placeholder="Any specific analysis preferences...")

            with gr.Row():
                analyze_btn = gr.Button("Analyze Data")
                clear_btn = gr.Button("Clear")

            output_text = gr.Markdown()
            output_gallery = gr.Gallery()

            def analyze(api_key, file, notes):
                """Run the full analysis; returns (markdown_summary, image_paths_or_None)."""
                if not api_key or not file:
                    return "Please provide both API key and data file.", None

                try:
                    # Locals keep one request's data from being swapped out by
                    # another request that reassigns the instance attributes.
                    df = pd.read_csv(file.name)
                    analyzer = DataAnalyzer(api_key)
                    self.df, self.analyzer = df, analyzer

                    # Get AI suggestions for analysis
                    prompt = f"Data columns: {list(df.columns)}\nUser notes: {notes}\nSuggest appropriate analyses and visualizations."
                    ai_suggestions = analyzer.call_gpt4o_mini(prompt)

                    # Perform analysis
                    data_types = analyzer.analyze_data_types(df)
                    stats_results = analyzer.perform_statistical_tests(df, data_types)

                    # Create visualizations
                    viz_paths = []
                    numeric_cols = data_types["numeric_cols"]
                    if numeric_cols:
                        for index, viz_type in enumerate(("correlation", "distribution", "boxplot")):
                            # The output file name is derived from
                            # current_iteration; give each plot its own index
                            # so the three images do not share one file.
                            analyzer.current_iteration = index
                            viz_paths.append(
                                analyzer.create_visualization(df, viz_type, numeric_cols)
                            )

                    # Generate summary. Assembled line by line so every
                    # Markdown heading starts at column 0, independent of how
                    # this source file is indented.
                    summary = "\n".join([
                        "",
                        "## Data Analysis Results",
                        "",
                        "### AI Suggestions",
                        f"{ai_suggestions}",
                        "",
                        "### Basic Statistics",
                        f"- Rows: {len(df)}",
                        f"- Columns: {len(df.columns)}",
                        f"- Missing Values: {sum(data_types['missing_values'].values())}",
                        "",
                        "### Statistical Tests",
                        f"{self._format_stats_results(stats_results)}",
                        "",
                    ])

                    return summary, viz_paths

                except Exception as e:
                    # UI boundary: show the failure to the user instead of crashing the handler.
                    return f"Error during analysis: {str(e)}", None

            analyze_btn.click(
                analyze,
                inputs=[api_key, file_input, analysis_notes],
                outputs=[output_text, output_gallery]
            )

            clear_btn.click(
                lambda: (None, None),
                outputs=[output_text, output_gallery]
            )

        return demo

    @staticmethod
    def _format_stats_results(results: Dict) -> str:
        """Format statistical results as one Markdown bullet per test.

        Keys are matched by prefix (``normality_`` / ``chi2_``) rather than by
        substring, so a chi-square key whose column name happens to contain
        "normality" is still formatted as a chi-square result. Entries with
        any other prefix are omitted.
        """
        formatted = []
        for test_name, result in results.items():
            if test_name.startswith("normality_"):
                formatted.append(f"- {test_name}: {'Normal' if result['is_normal'] else 'Non-normal'} "
                                 f"(p={result['p_value']:.4f})")
            elif test_name.startswith("chi2_"):
                formatted.append(f"- {test_name}: {'Significant' if result['is_significant'] else 'Not significant'} "
                                 f"(p={result['p_value']:.4f})")
        return "\n".join(formatted)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
if __name__ == "__main__":
    # Script entry point: build the UI, then start the Gradio server.
    interface = GradioInterface()
    demo = interface.create_interface()
    demo.launch(share=True)