jzou19950715 commited on
Commit
bcd9ccf
·
verified ·
1 Parent(s): 6a8cba0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +329 -0
app.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import subprocess
4
+ import requests
5
+ import gradio as gr
6
+ import pandas as pd
7
+ import matplotlib.pyplot as plt
8
+ import seaborn as sns
9
+ import numpy as np
10
+
11
+ from sklearn.model_selection import train_test_split
12
+ from sklearn.linear_model import LogisticRegression
13
+ from sklearn.preprocessing import LabelEncoder
14
+
15
+ # --------------------------------------------------------------------------------
16
+ # OPTIONAL: dynamic installation for rarely used packages not in requirements.txt
17
+ # --------------------------------------------------------------------------------
18
+
19
+ def install_library(library):
20
+ """
21
+ Install a library using pip.
22
+ Useful for rarely used packages NOT in requirements.txt.
23
+ """
24
+ try:
25
+ subprocess.check_call([sys.executable, "-m", "pip", "install", library])
26
+ return f"Successfully installed {library}."
27
+ except Exception as e:
28
+ return f"Error installing {library}: {str(e)}"
29
+
30
+ def dynamic_import(library, alias=None):
31
+ """
32
+ Dynamically import a library. If not found, try to install it, then import again.
33
+ """
34
+ try:
35
+ if alias:
36
+ globals()[alias] = __import__(library)
37
+ else:
38
+ globals()[library] = __import__(library)
39
+ except ImportError:
40
+ install_msg = install_library(library)
41
+ print(install_msg)
42
+ globals()[library] = __import__(library)
43
+
44
+
45
+ # --------------------------------------------------------------------------------
46
+ # LLM CALLS: GPT-4o-mini, OpenAI, DeepSeek, Gemini
47
+ # --------------------------------------------------------------------------------
48
+ import openai
49
+ from huggingface_hub import InferenceClient
50
+
51
+ def call_gpt4o_mini(api_key, user_prompt):
52
+ """
53
+ Calls a GPT-4o-mini model hosted on Hugging Face.
54
+ Replace 'someUser/gpt-4o-mini' with your actual model repo.
55
+ """
56
+ if not api_key:
57
+ return "No Hugging Face API key provided. Cannot call GPT-4o-mini."
58
+
59
+ try:
60
+ client = InferenceClient(
61
+ repo_id="someUser/gpt-4o-mini", # <--- Replace with your real GPT-4o-mini repo
62
+ token=api_key
63
+ )
64
+ # We use text_generation endpoint; adapt if your model differs
65
+ response = client.text_generation(user_prompt, max_new_tokens=128)
66
+ # 'response' can be a string or dict depending on the endpoint. Assume it's a string:
67
+ return response
68
+ except Exception as e:
69
+ return f"Error calling GPT-4o-mini: {str(e)}"
70
+
71
+ def call_openai(api_key, user_prompt):
72
+ """Calls OpenAI's API (example usage)."""
73
+ openai.api_key = api_key
74
+ try:
75
+ response = openai.Completion.create(
76
+ model="text-davinci-003",
77
+ prompt=user_prompt,
78
+ max_tokens=128
79
+ )
80
+ return response["choices"][0]["text"].strip()
81
+ except Exception as e:
82
+ return f"OpenAI Error: {str(e)}"
83
+
84
+ def call_deepseek(api_key, user_prompt):
85
+ """
86
+ Hypothetical function to call a DeepSeek API endpoint.
87
+ Replace with real DeepSeek logic as needed.
88
+ """
89
+ try:
90
+ headers = {
91
+ "Content-Type": "application/json",
92
+ "Authorization": f"Bearer {api_key}"
93
+ }
94
+ payload = {
95
+ "prompt": user_prompt,
96
+ "max_tokens": 128
97
+ }
98
+ # Example POST; adapt to the real DeepSeek endpoint
99
+ response = requests.post(
100
+ "https://api.deepseek.ai/v1/chat",
101
+ json=payload,
102
+ headers=headers
103
+ )
104
+ response.raise_for_status()
105
+ data = response.json()
106
+ return data["choices"][0]["text"].strip()
107
+ except Exception as e:
108
+ return f"DeepSeek Error: {str(e)}"
109
+
110
+ def call_gemini(api_key, user_prompt):
111
+ """
112
+ Hypothetical function for Gemini LLM.
113
+ Replace with real Gemini logic.
114
+ """
115
+ return "(Gemini usage not yet implemented; placeholder)"
116
+
117
+ def call_llm(api_provider, api_key, user_prompt):
118
+ """Routes calls to the correct LLM provider."""
119
+ if not api_key:
120
+ return "No API key provided. Using GPT-4o-mini default is not possible without HF key." if api_provider.lower() == "gpt-4o-mini" else "No API key provided."
121
+
122
+ provider_lower = api_provider.lower()
123
+ if provider_lower == "gpt-4o-mini":
124
+ return call_gpt4o_mini(api_key, user_prompt)
125
+ elif provider_lower == "openai":
126
+ return call_openai(api_key, user_prompt)
127
+ elif provider_lower == "deepseek":
128
+ return call_deepseek(api_key, user_prompt)
129
+ elif provider_lower == "gemini":
130
+ return call_gemini(api_key, user_prompt)
131
+ else:
132
+ return f"Unknown provider: {api_provider}. Please choose GPT-4o-mini, OpenAI, DeepSeek, or Gemini."
133
+
134
+ # --------------------------------------------------------------------------------
135
+ # ADVANCED DATA ANALYSIS (extended_analysis)
136
+ # --------------------------------------------------------------------------------
137
+ def extended_analysis(df):
138
+ """
139
+ Sample advanced analysis:
140
+ 1. Correlation heatmap for numeric columns
141
+ 2. Bar plot of 'Career' (if present)
142
+ 3. Simple logistic regression classification if 'Career' is suitable
143
+ """
144
+ output_paths = []
145
+ numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
146
+
147
+ cat_cols = df.select_dtypes(exclude=["number"]).columns.tolist()
148
+
149
+ # 1) Correlation Heatmap
150
+ if len(numeric_cols) > 1:
151
+ corr = df[numeric_cols].corr()
152
+ plt.figure(figsize=(8, 6))
153
+ sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f")
154
+ plt.title("Correlation Heatmap")
155
+ heatmap_path = "heatmap.png"
156
+ plt.savefig(heatmap_path)
157
+ plt.close()
158
+ output_paths.append(heatmap_path)
159
+
160
+ # 2) Bar Plot of 'Career' if present
161
+ if "Career" in df.columns:
162
+ plt.figure(figsize=(8, 5))
163
+ df["Career"].value_counts().plot(kind="bar")
164
+ plt.title("Count of Each Career")
165
+ plt.xlabel("Career")
166
+ plt.ylabel("Count")
167
+ barplot_path = "barplot_career.png"
168
+ plt.savefig(barplot_path)
169
+ plt.close()
170
+ output_paths.append(barplot_path)
171
+
172
+ # 3) Simple Logistic Regression if 'Career' exists with multiple categories
173
+ if "Career" in df.columns and len(numeric_cols) > 0:
174
+ le = LabelEncoder()
175
+ df["Career_encoded"] = le.fit_transform(df["Career"])
176
+ X = df[numeric_cols].fillna(0)
177
+ y = df["Career_encoded"]
178
+ if len(np.unique(y)) > 1:
179
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
180
+ model = LogisticRegression(max_iter=1000)
181
+ model.fit(X_train, y_train)
182
+ score = model.score(X_test, y_test)
183
+ accuracy_info = f"Logistic Regression accuracy on test set: {score:.2f}"
184
+ else:
185
+ accuracy_info = "Career column has only one class; no classification performed."
186
+ else:
187
+ accuracy_info = "No 'Career' column or insufficient numeric data for classification."
188
+
189
+ return output_paths, accuracy_info
190
+
191
+ # --------------------------------------------------------------------------------
192
+ # MAIN ANALYSIS AND VISUALIZATION FUNCTION
193
+ # --------------------------------------------------------------------------------
194
+ def analyze_and_visualize(
195
+ file,
196
+ message,
197
+ history,
198
+ api_provider,
199
+ api_key
200
+ ):
201
+ """
202
+ Loads CSV, gives a summary, calls LLM for suggestions if an API key is provided,
203
+ does extended analysis if user requests ("sample analysis", "extended analysis", etc.),
204
+ and returns results/plots in the chatbot.
205
+ """
206
+ try:
207
+ # Load CSV
208
+ df = pd.read_csv(file.name)
209
+ numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
210
+ categorical_cols = df.select_dtypes(exclude=["number"]).columns.tolist()
211
+
212
+ # Basic info
213
+ summary = (
214
+ f"**File**: {file.name}\n"
215
+ f"**Shape**: {df.shape[0]} rows, {df.shape[1]} columns\n"
216
+ f"**Numerical Columns**: {', '.join(numeric_cols) if numeric_cols else 'None'}\n"
217
+ f"**Categorical Columns**: {', '.join(categorical_cols) if categorical_cols else 'None'}\n"
218
+ )
219
+
220
+ # LLM suggestions
221
+ llm_suggestions = ""
222
+ if api_key:
223
+ user_prompt = (
224
+ f"Data Summary:\n{summary}\n\n"
225
+ f"User question or request: {message}\n"
226
+ f"Suggest advanced data analysis or steps if relevant."
227
+ )
228
+ llm_response = call_llm(api_provider, api_key, user_prompt)
229
+ llm_suggestions = f"\n**LLM Suggestions**:\n{llm_response}\n"
230
+ else:
231
+ llm_suggestions = "\n(No LLM suggestions because no API key provided.)\n"
232
+
233
+ # Always produce example histogram if there's at least one numeric column
234
+ hist_path = None
235
+ if numeric_cols:
236
+ plt.figure(figsize=(6, 4))
237
+ sns.histplot(df[numeric_cols[0]], kde=True)
238
+ plt.title(f"Distribution of '{numeric_cols[0]}'")
239
+ plt.tight_layout()
240
+ hist_path = "temp_plot.png"
241
+ plt.savefig(hist_path)
242
+ plt.close()
243
+
244
+ # Check if the user wants extended analysis
245
+ trigger_phrases = ["sample analysis", "extended analysis", "advanced analysis", "run analysis"]
246
+ analysis_paths = []
247
+ accuracy_info = ""
248
+ if any(phrase in message.lower() for phrase in trigger_phrases):
249
+ analysis_paths, accuracy_info = extended_analysis(df)
250
+
251
+ # Build final response text
252
+ response_text = summary + llm_suggestions
253
+ if accuracy_info:
254
+ response_text += f"\n**ML Model Info**: {accuracy_info}\n"
255
+
256
+ # Construct the final chatbot content
257
+ chat_content = [(message, response_text)]
258
+ if hist_path:
259
+ chat_content.append((None, (hist_path,)))
260
+ for path in analysis_paths:
261
+ chat_content.append((None, (path,)))
262
+
263
+ return history + chat_content
264
+
265
+ except Exception as e:
266
+ return history + [(message, f"Error: {str(e)}")]
267
+
268
+ # --------------------------------------------------------------------------------
269
+ # CREATING THE GRADIO APP
270
+ # --------------------------------------------------------------------------------
271
+ def create_demo():
272
+ with gr.Blocks() as demo:
273
+ gr.Markdown("# 🤖 GPT-4o-mini (Default) + Multi-Provider AI Data Analysis Assistant")
274
+ gr.Markdown(
275
+ """
276
+ **Features**:
277
+ - Default LLM: GPT-4o-mini on Hugging Face (requires HF API key).
278
+ - Other providers: **OpenAI**, **DeepSeek**, **Gemini** (enter their respective API keys).
279
+ - Upload CSV for data summary & histograms.
280
+ - Type "sample analysis" or "extended analysis" to trigger correlation heatmaps, bar plots, and a simple logistic regression.
281
+ """
282
+ )
283
+
284
+ with gr.Row():
285
+ api_provider = gr.Dropdown(
286
+ choices=["GPT-4o-mini", "OpenAI", "DeepSeek", "Gemini"],
287
+ value="GPT-4o-mini", # default
288
+ label="LLM Provider",
289
+ )
290
+ api_key = gr.Textbox(
291
+ label="LLM API Key",
292
+ placeholder="Enter your Hugging Face/DeepSeek/OpenAI/Gemini API key here..."
293
+ )
294
+
295
+ file_input = gr.File(label="Upload CSV File", file_types=[".csv"])
296
+ chatbot = gr.Chatbot(label="Analysis Output")
297
+ msg = gr.Textbox(
298
+ label="Message",
299
+ placeholder="Ask the AI or type 'sample analysis' for extended analysis..."
300
+ )
301
+
302
+ send_btn = gr.Button("Send")
303
+ reset_btn = gr.Button("Reset Chat")
304
+
305
+ def reset_chat():
306
+ return []
307
+
308
+ msg.submit(
309
+ fn=lambda f, m, h, p, k: analyze_and_visualize(f, m, h or [], p, k),
310
+ inputs=[file_input, msg, chatbot, api_provider, api_key],
311
+ outputs=[chatbot]
312
+ ).then(lambda: "", None, [msg])
313
+
314
+ send_btn.click(
315
+ fn=lambda f, m, h, p, k: analyze_and_visualize(f, m, h or [], p, k),
316
+ inputs=[file_input, msg, chatbot, api_provider, api_key],
317
+ outputs=[chatbot]
318
+ ).then(lambda: "", None, [msg])
319
+
320
+ reset_btn.click(fn=reset_chat, inputs=[], outputs=[chatbot])
321
+
322
+ demo.queue()
323
+ return demo
324
+
325
+
326
+ demo = create_demo()
327
+
328
+ if __name__ == "__main__":
329
+ demo.launch(share=True)