jzou19950715 commited on
Commit
4722ac6
·
verified ·
1 Parent(s): 57afec9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +216 -261
app.py CHANGED
@@ -5,291 +5,246 @@ import pandas as pd
5
  import matplotlib.pyplot as plt
6
  import seaborn as sns
7
  import numpy as np
 
8
  from sklearn.model_selection import train_test_split
9
  from sklearn.linear_model import LogisticRegression
10
  from sklearn.preprocessing import LabelEncoder
11
- import tempfile
12
- import shutil
13
-
14
- # Create a temporary directory for plot files
15
- TEMP_DIR = tempfile.mkdtemp()
16
-
17
- def cleanup_temp_files():
18
- """Clean up temporary files when the application exits"""
19
- try:
20
- shutil.rmtree(TEMP_DIR)
21
- except Exception as e:
22
- print(f"Error cleaning up temporary files: {e}")
23
-
24
- def get_temp_path(filename):
25
- """Generate a path for temporary files"""
26
- return os.path.join(TEMP_DIR, filename)
27
-
28
- def call_gpt4o_mini(api_key, context, user_prompt):
29
- """Enhanced GPT-4o-mini call with better error handling"""
30
- if not api_key or not api_key.startswith('sk-'):
31
- return "Please provide a valid API key (should start with 'sk-')"
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  try:
34
- url = os.getenv('GPT4O_MINI_API_URL', 'https://api.openai.com/v1/chat/completions')
35
- headers = {
36
- "Authorization": f"Bearer {api_key}",
37
- "Content-Type": "application/json"
38
- }
39
-
40
- messages = [
41
- {"role": "system", "content": "You are a data analysis assistant. Analyze the provided data and context."},
42
- {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {user_prompt}"}
43
- ]
44
-
45
- payload = {
46
- "model": "gpt-4",
47
- "messages": messages,
48
- "max_tokens": 500,
49
- "temperature": 0.7
50
- }
51
-
52
  response = requests.post(url, json=payload, headers=headers, timeout=10)
53
- if response.status_code == 401:
54
- return "Invalid API key. Please check your credentials."
55
  response.raise_for_status()
56
- return response.json()["choices"][0]["message"]["content"]
57
- except requests.exceptions.RequestException as e:
58
- return f"API Error: {str(e)}"
59
  except Exception as e:
60
- return f"Error: {str(e)}"
61
 
 
 
 
 
62
  def extended_analysis(df):
63
- """Perform advanced analysis with improved error handling"""
 
 
 
64
  output_paths = []
65
  numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
66
-
67
- try:
68
- # 1. Correlation Heatmap
69
- if len(numeric_cols) > 1:
70
- plt.figure(figsize=(12, 8))
71
- corr = df[numeric_cols].corr()
72
- mask = np.triu(np.ones_like(corr, dtype=bool))
73
- sns.heatmap(corr, mask=mask, annot=True, cmap="coolwarm", fmt=".2f",
74
- square=True, linewidths=.5)
75
- plt.title("Correlation Heatmap of Numeric Features")
76
- plt.tight_layout()
77
- heatmap_path = get_temp_path("heatmap.png")
78
- plt.savefig(heatmap_path, dpi=300, bbox_inches='tight')
79
- plt.close()
80
- output_paths.append(heatmap_path)
81
-
82
- # 2. Career Distribution
83
- if "Career" in df.columns:
84
- plt.figure(figsize=(12, 6))
85
- career_counts = df["Career"].value_counts()
86
- sns.barplot(x=career_counts.index, y=career_counts.values)
87
- plt.title("Distribution of Careers")
88
- plt.xticks(rotation=45, ha='right')
89
- plt.xlabel("Career")
90
- plt.ylabel("Count")
91
- plt.tight_layout()
92
- barplot_path = get_temp_path("career_distribution.png")
93
- plt.savefig(barplot_path, dpi=300, bbox_inches='tight')
94
- plt.close()
95
- output_paths.append(barplot_path)
96
-
97
- # 3. Box Plots for Scores
98
- score_columns = [col for col in df.columns if 'score' in col.lower() or 'aptitude' in col.lower()]
99
- if score_columns:
100
- plt.figure(figsize=(15, 8))
101
- df[score_columns].boxplot()
102
- plt.title("Distribution of Scores and Aptitudes")
103
- plt.xticks(rotation=45, ha='right')
104
- plt.ylabel("Score")
105
- plt.grid(True, alpha=0.3)
106
- plt.tight_layout()
107
- boxplot_path = get_temp_path("scores_distribution.png")
108
- plt.savefig(boxplot_path, dpi=300, bbox_inches='tight')
109
- plt.close()
110
- output_paths.append(boxplot_path)
111
-
112
- # 4. Machine Learning Analysis
113
- if "Career" in df.columns and len(numeric_cols) > 0:
114
- le = LabelEncoder()
115
- df["Career_encoded"] = le.fit_transform(df["Career"])
116
- X = df[numeric_cols].fillna(df[numeric_cols].mean())
117
- y = df["Career_encoded"]
118
-
119
- if len(np.unique(y)) > 1:
120
- X_train, X_test, y_train, y_test = train_test_split(
121
- X, y, test_size=0.2, random_state=42
122
- )
123
- model = LogisticRegression(max_iter=1000)
124
- model.fit(X_train, y_train)
125
- score = model.score(X_test, y_test)
126
-
127
- # Feature importance visualization
128
- plt.figure(figsize=(10, 6))
129
- importance = pd.DataFrame({
130
- 'feature': numeric_cols,
131
- 'importance': np.abs(model.coef_[0])
132
- }).sort_values('importance', ascending=True)
133
-
134
- sns.barplot(data=importance, x='importance', y='feature')
135
- plt.title("Feature Importance in Career Prediction")
136
- plt.tight_layout()
137
- importance_path = get_temp_path("feature_importance.png")
138
- plt.savefig(importance_path, dpi=300, bbox_inches='tight')
139
- plt.close()
140
- output_paths.append(importance_path)
141
-
142
- return output_paths, f"Model accuracy: {score:.2f}"
143
- else:
144
- return output_paths, "Insufficient unique careers for classification"
145
- return output_paths, "Analysis completed successfully"
146
-
147
- except Exception as e:
148
- return output_paths, f"Error during analysis: {str(e)}"
149
-
150
- class ChatHistory:
151
- """Simple chat history manager without Gradio component inheritance"""
152
- def __init__(self):
153
- self.messages = []
154
- self.data_summary = ""
155
- self.current_df = None
156
-
157
- def add_message(self, role, content):
158
- self.messages.append({"role": role, "content": content})
159
- if len(self.messages) > 10: # Keep last 10 messages
160
- self.messages.pop(0)
161
-
162
- def get_context(self):
163
- return "\n".join([f"{m['role']}: {m['content']}" for m in self.messages[-5:]])
164
-
165
- def set_data_summary(self, summary):
166
- self.data_summary = summary
167
-
168
- def get_full_context(self):
169
- return f"Data Summary:\n{self.data_summary}\n\nChat History:\n{self.get_context()}"
170
-
171
- def clear(self):
172
- self.messages = []
173
- self.data_summary = ""
174
- self.current_df = None
175
-
176
- def analyze_and_visualize(file, message, history, api_key, chat_history):
177
- """Main function for data analysis with improved state management"""
178
- if not file and chat_history.current_df is None:
179
- return history + [(message, "Please upload a CSV file first.")], chat_history
180
 
181
- try:
182
- # Load data if new file uploaded
183
- if file:
184
- df = pd.read_csv(file.name)
185
- chat_history.current_df = df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  else:
187
- df = chat_history.current_df
188
-
189
- numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
190
- cat_cols = df.select_dtypes(exclude=["number"]).columns.tolist()
191
-
192
- summary = (
193
- f"📊 Data Summary:\n"
194
- f"• Rows: {df.shape[0]}, Columns: {df.shape[1]}\n"
195
- f"• Numeric columns: {', '.join(numeric_cols) if numeric_cols else 'None'}\n"
196
- f"• Categorical columns: {', '.join(cat_cols) if cat_cols else 'None'}\n"
197
- f"\nDescriptive Statistics:\n{df[numeric_cols].describe().round(2).to_string()}\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  )
199
 
200
- # Update chat history with new data summary if file was uploaded
201
- if file:
202
- chat_history.set_data_summary(summary)
203
 
204
- # Add user message to context
205
- chat_history.add_message("user", message)
206
-
207
- # Get GPT-4o-mini insights if API key provided
208
- if api_key:
209
- gpt_response = call_gpt4o_mini(
210
- api_key,
211
- chat_history.get_full_context(),
212
- message
213
- )
214
- chat_history.add_message("assistant", gpt_response)
215
- response_text = f"{summary}\n\n🤖 AI Insights:\n{gpt_response}"
216
- else:
217
- response_text = f"{summary}\n\nNote: Add an API key for AI-powered insights."
218
-
219
- # Generate visualizations based on user message
220
- viz_triggers = ["visualize", "plot", "show", "graph", "analyze", "distribution"]
221
- if any(trigger in message.lower() for trigger in viz_triggers):
222
- analysis_paths, analysis_info = extended_analysis(df)
223
- if analysis_info:
224
- response_text += f"\n\n📈 Analysis Results:\n{analysis_info}"
225
-
226
- chat_content = [(message, response_text)]
227
- for path in analysis_paths:
228
- chat_content.append((None, (path,)))
229
- return history + chat_content, chat_history
230
-
231
- return history + [(message, response_text)], chat_history
232
 
233
- except Exception as e:
234
- error_msg = f"Error processing request: {str(e)}"
235
- chat_history.add_message("system", error_msg)
236
- return history + [(message, error_msg)], chat_history
237
 
238
- def create_demo():
239
- chat_history = ChatHistory()
240
-
241
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
242
- gr.Markdown("# Interactive Data Analysis Assistant")
243
- gr.Markdown("""
244
- **Instructions**:
245
- 1. (Optional) Enter your API key for AI-powered insights
246
- 2. Upload your CSV file
247
- 3. Ask questions about your data or request visualizations
248
- 4. Type keywords like 'visualize', 'plot', 'analyze' to generate charts
249
- """)
250
-
251
- with gr.Row():
252
- with gr.Column():
253
- api_key = gr.Textbox(
254
- label="API Key (Optional)",
255
- placeholder="Enter your API key here",
256
- type="password"
257
- )
258
- file_input = gr.File(label="Upload CSV", file_types=[".csv"])
259
- msg = gr.Textbox(
260
- label="Ask about your data",
261
- placeholder="e.g., 'Show me the distribution of scores' or 'What insights can you find?'"
262
- )
263
- with gr.Row():
264
- send_btn = gr.Button("Send")
265
- clear_btn = gr.Button("Clear Chat")
266
-
267
- with gr.Column():
268
- chatbot = gr.Chatbot(height=600)
269
-
270
- msg.submit(
271
- fn=analyze_and_visualize,
272
- inputs=[file_input, msg, chatbot, api_key, chat_history],
273
- outputs=[chatbot, chat_history]
274
- ).then(lambda: "", None, [msg])
275
 
 
 
276
  send_btn.click(
277
- fn=analyze_and_visualize,
278
- inputs=[file_input, msg, chatbot, api_key, chat_history],
279
- outputs=[chatbot, chat_history]
280
- ).then(lambda: "", None, [msg])
 
 
 
 
 
 
 
281
 
 
 
 
 
282
  clear_btn.click(
283
- fn=lambda: ([], ChatHistory()),
284
- outputs=[chatbot, chat_history]
 
285
  )
286
 
287
- demo.queue()
288
- return demo
 
289
 
290
  if __name__ == "__main__":
291
- try:
292
- demo = create_demo()
293
- demo.launch()
294
- finally:
295
- cleanup_temp_files()
 
5
  import matplotlib.pyplot as plt
6
  import seaborn as sns
7
  import numpy as np
8
+
9
  from sklearn.model_selection import train_test_split
10
  from sklearn.linear_model import LogisticRegression
11
  from sklearn.preprocessing import LabelEncoder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ ##############################################################################
14
+ # GPT-4o-mini Placeholder - Adjust for your real endpoint & JSON
15
+ ##############################################################################
16
+ def call_gpt4o_mini(api_key, user_prompt):
17
+ """
18
+ Hypothetical call to GPT-4o-mini with an sk-... style token.
19
+ Example endpoint: https://api.gpt4o-mini.com/v1/chat
20
+ - Adjust JSON structure and keys to your actual service spec.
21
+ """
22
+ if not api_key or not api_key.startswith("sk-"):
23
+ return "Please provide a valid GPT-4o-mini token (sk-...)."
24
+
25
+ url = "https://api.gpt4o-mini.com/v1/chat" # <--- Replace with real endpoint
26
+ headers = {
27
+ "Authorization": f"Bearer {api_key}",
28
+ "Content-Type": "application/json",
29
+ }
30
+ payload = {
31
+ "prompt": user_prompt,
32
+ "max_tokens": 128, # limit tokens for cost
33
+ "temperature": 0.7,
34
+ }
35
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  response = requests.post(url, json=payload, headers=headers, timeout=10)
 
 
37
  response.raise_for_status()
38
+ data = response.json()
39
+ # Suppose the text is in data["choices"][0]["text"] (adjust if needed)
40
+ return data["choices"][0]["text"]
41
  except Exception as e:
42
+ return f"Error calling GPT-4o-mini: {str(e)}"
43
 
44
+
45
+ ##############################################################################
46
+ # Local Data Analysis
47
+ ##############################################################################
48
  def extended_analysis(df):
49
+ """
50
+ Does correlation heatmap, bar plot for 'Career', and logistic regression
51
+ if 'Career' has multiple categories. Returns (list_of_image_paths, info_string).
52
+ """
53
  output_paths = []
54
  numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ # 1) Correlation Heatmap
57
+ if len(numeric_cols) > 1:
58
+ corr = df[numeric_cols].corr()
59
+ plt.figure(figsize=(8, 6))
60
+ sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f")
61
+ plt.title("Correlation Heatmap")
62
+ heatmap_path = "heatmap.png"
63
+ plt.savefig(heatmap_path)
64
+ plt.close()
65
+ output_paths.append(heatmap_path)
66
+
67
+ # 2) Bar Plot for 'Career'
68
+ if "Career" in df.columns:
69
+ plt.figure(figsize=(8, 5))
70
+ career_counts = df["Career"].value_counts()
71
+ sns.barplot(x=career_counts.index, y=career_counts.values)
72
+ plt.title("Distribution of Careers")
73
+ plt.xlabel("Career")
74
+ plt.ylabel("Count")
75
+ plt.xticks(rotation=45, ha="right")
76
+ barplot_path = "career_distribution.png"
77
+ plt.savefig(barplot_path)
78
+ plt.close()
79
+ output_paths.append(barplot_path)
80
+
81
+ # 3) Simple Logistic Regression
82
+ if "Career" in df.columns and len(numeric_cols) > 0:
83
+ le = LabelEncoder()
84
+ df["Career_encoded"] = le.fit_transform(df["Career"])
85
+ X = df[numeric_cols].fillna(0)
86
+ y = df["Career_encoded"]
87
+ if len(np.unique(y)) > 1:
88
+ X_train, X_test, y_train, y_test = train_test_split(
89
+ X, y, test_size=0.2, random_state=42
90
+ )
91
+ model = LogisticRegression(max_iter=1000)
92
+ model.fit(X_train, y_train)
93
+ score = model.score(X_test, y_test)
94
+ accuracy_info = f"Logistic Regression accuracy: {score:.2f}"
95
  else:
96
+ accuracy_info = "Only one category in 'Career'; no classification performed."
97
+ else:
98
+ accuracy_info = "No 'Career' column or insufficient numeric columns for classification."
99
+
100
+ return output_paths, accuracy_info
101
+
102
+
103
+ ##############################################################################
104
+ # Main Chat/Analysis Function
105
+ ##############################################################################
106
+ def handle_chat(user_message, df, chat_history, api_key):
107
+ """
108
+ - If df is None, prompt user to upload a CSV.
109
+ - Else, do local analysis and optionally call GPT-4o-mini for suggestions.
110
+ - Update the chat_history with role='user' or role='assistant' messages.
111
+ - Return new chat_history in 'messages' format for the Gradio Chatbot (type='messages').
112
+ """
113
+ if df is None:
114
+ chat_history.append({"role": "assistant", "content": "Please upload a CSV first."})
115
+ return chat_history
116
+
117
+ # Summarize data
118
+ numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
119
+ cat_cols = df.select_dtypes(exclude=["number"]).columns.tolist()
120
+ summary = (
121
+ f"Rows: {df.shape[0]}, Columns: {df.shape[1]}\n"
122
+ f"Numeric: {', '.join(numeric_cols) if numeric_cols else 'None'}\n"
123
+ f"Categorical: {', '.join(cat_cols) if cat_cols else 'None'}"
124
+ )
125
+
126
+ # Always show user message in chat
127
+ chat_history.append({"role": "user", "content": user_message})
128
+
129
+ # Possibly call GPT-4o-mini for suggestions
130
+ gpt_reply = ""
131
+ if api_key:
132
+ prompt = f"Data Summary:\n{summary}\nUser Query: {user_message}"
133
+ gpt_reply = call_gpt4o_mini(api_key, prompt)
134
+
135
+ # Build the reply text (local summary + LLM suggestions)
136
+ reply_text = f"**Data Summary**:\n{summary}"
137
+ if gpt_reply:
138
+ reply_text += f"\n\n**GPT-4o-mini**: {gpt_reply}"
139
+
140
+ # Check if user wants extended analysis
141
+ triggers = ["sample analysis", "extended analysis", "advanced analysis", "run analysis", "visualize", "plot"]
142
+ if any(t in user_message.lower() for t in triggers):
143
+ # Perform extended analysis
144
+ image_paths, info = extended_analysis(df)
145
+ if info:
146
+ reply_text += f"\n\n**Analysis Info**: {info}"
147
+ # Add images to chat
148
+ chat_history.append({"role": "assistant", "content": reply_text})
149
+ # Return images as separate chat items
150
+ for path in image_paths:
151
+ chat_history.append({"role": "assistant", "content": None, "image": path})
152
+ return chat_history
153
+
154
+ # If no extended analysis triggered, just add the text
155
+ chat_history.append({"role": "assistant", "content": reply_text})
156
+ return chat_history
157
+
158
+
159
+ ##############################################################################
160
+ # Gradio Interface
161
+ ##############################################################################
162
+ def create_demo():
163
+ with gr.Blocks() as demo:
164
+ # State: holds the DataFrame and the chat messages
165
+ df_state = gr.State(None)
166
+ chat_state = gr.State([]) # store messages as list of dicts: [{"role": "...", "content": "..."}]
167
+
168
+ gr.Markdown("## GPT-4o-mini Data Analysis Assistant (Chat)")
169
+ gr.Markdown(
170
+ """
171
+ 1. Enter your GPT-4o-mini token (`sk-...`) if you want AI suggestions.
172
+ 2. Upload a CSV file.
173
+ 3. Ask questions or request "sample analysis", "visualize", etc.
174
+ 4. Images are displayed in the chat when relevant.
175
+ """
176
  )
177
 
178
+ api_key_box = gr.Textbox(label="GPT-4o-mini Token (sk-...)", placeholder="Optional: sk-xxxx")
179
+ file_input = gr.File(label="Upload CSV", file_types=[".csv"])
 
180
 
181
+ # Chatbot in "messages" format to fix the deprecation warning
182
+ chatbot = gr.Chatbot(label="Chat Output", type="messages")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
+ user_message = gr.Textbox(label="Your Message", placeholder="Ask about your data...")
 
 
 
185
 
186
+ def upload_csv(file):
187
+ """
188
+ On file upload, load the DataFrame into df_state and reset the chat if needed.
189
+ """
190
+ if file is None:
191
+ return None
192
+ df = pd.read_csv(file.name)
193
+ return df
194
+
195
+ file_input.change(fn=upload_csv, inputs=file_input, outputs=df_state)
196
+
197
+ def on_user_message(message, df, chat_history, api_key):
198
+ """
199
+ Called when user sends a message. Handle chat + analysis. Return new chat messages.
200
+ """
201
+ if not message.strip():
202
+ return chat_history # ignore empty
203
+ updated_history = handle_chat(message, df, chat_history, api_key)
204
+ return updated_history
205
+
206
+ user_message.submit(
207
+ fn=on_user_message,
208
+ inputs=[user_message, df_state, chat_state, api_key_box],
209
+ outputs=chat_state
210
+ ).then(
211
+ # After updating chat_state, reflect it in the chatbot
212
+ fn=lambda messages: messages,
213
+ inputs=chat_state,
214
+ outputs=chatbot
215
+ ).then(
216
+ fn=lambda: "",
217
+ outputs=user_message
218
+ )
 
 
 
 
219
 
220
+ # Button to send message
221
+ send_btn = gr.Button("Send")
222
  send_btn.click(
223
+ fn=on_user_message,
224
+ inputs=[user_message, df_state, chat_state, api_key_box],
225
+ outputs=chat_state
226
+ ).then(
227
+ fn=lambda messages: messages,
228
+ inputs=chat_state,
229
+ outputs=chatbot
230
+ ).then(
231
+ fn=lambda: "",
232
+ outputs=user_message
233
+ )
234
 
235
+ # Clear chat button
236
+ clear_btn = gr.Button("Clear Chat")
237
+ def clear_chat():
238
+ return [], []
239
  clear_btn.click(
240
+ fn=clear_chat,
241
+ inputs=[],
242
+ outputs=[chat_state, chatbot]
243
  )
244
 
245
+ return demo
246
+
247
+ demo = create_demo()
248
 
249
  if __name__ == "__main__":
250
+ demo.launch(share=True)