Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import subprocess
|
4 |
+
import requests
|
5 |
+
import gradio as gr
|
6 |
+
import pandas as pd
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import seaborn as sns
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from sklearn.model_selection import train_test_split
|
12 |
+
from sklearn.linear_model import LogisticRegression
|
13 |
+
from sklearn.preprocessing import LabelEncoder
|
14 |
+
|
15 |
+
# --------------------------------------------------------------------------------
|
16 |
+
# OPTIONAL: dynamic installation for rarely used packages not in requirements.txt
|
17 |
+
# --------------------------------------------------------------------------------
|
18 |
+
|
19 |
+
def install_library(library):
|
20 |
+
"""
|
21 |
+
Install a library using pip.
|
22 |
+
Useful for rarely used packages NOT in requirements.txt.
|
23 |
+
"""
|
24 |
+
try:
|
25 |
+
subprocess.check_call([sys.executable, "-m", "pip", "install", library])
|
26 |
+
return f"Successfully installed {library}."
|
27 |
+
except Exception as e:
|
28 |
+
return f"Error installing {library}: {str(e)}"
|
29 |
+
|
30 |
+
def dynamic_import(library, alias=None):
|
31 |
+
"""
|
32 |
+
Dynamically import a library. If not found, try to install it, then import again.
|
33 |
+
"""
|
34 |
+
try:
|
35 |
+
if alias:
|
36 |
+
globals()[alias] = __import__(library)
|
37 |
+
else:
|
38 |
+
globals()[library] = __import__(library)
|
39 |
+
except ImportError:
|
40 |
+
install_msg = install_library(library)
|
41 |
+
print(install_msg)
|
42 |
+
globals()[library] = __import__(library)
|
43 |
+
|
44 |
+
|
45 |
+
# --------------------------------------------------------------------------------
|
46 |
+
# LLM CALLS: GPT-4o-mini, OpenAI, DeepSeek, Gemini
|
47 |
+
# --------------------------------------------------------------------------------
|
48 |
+
import openai
|
49 |
+
from huggingface_hub import InferenceClient
|
50 |
+
|
51 |
+
def call_gpt4o_mini(api_key, user_prompt):
|
52 |
+
"""
|
53 |
+
Calls a GPT-4o-mini model hosted on Hugging Face.
|
54 |
+
Replace 'someUser/gpt-4o-mini' with your actual model repo.
|
55 |
+
"""
|
56 |
+
if not api_key:
|
57 |
+
return "No Hugging Face API key provided. Cannot call GPT-4o-mini."
|
58 |
+
|
59 |
+
try:
|
60 |
+
client = InferenceClient(
|
61 |
+
repo_id="someUser/gpt-4o-mini", # <--- Replace with your real GPT-4o-mini repo
|
62 |
+
token=api_key
|
63 |
+
)
|
64 |
+
# We use text_generation endpoint; adapt if your model differs
|
65 |
+
response = client.text_generation(user_prompt, max_new_tokens=128)
|
66 |
+
# 'response' can be a string or dict depending on the endpoint. Assume it's a string:
|
67 |
+
return response
|
68 |
+
except Exception as e:
|
69 |
+
return f"Error calling GPT-4o-mini: {str(e)}"
|
70 |
+
|
71 |
+
def call_openai(api_key, user_prompt):
|
72 |
+
"""Calls OpenAI's API (example usage)."""
|
73 |
+
openai.api_key = api_key
|
74 |
+
try:
|
75 |
+
response = openai.Completion.create(
|
76 |
+
model="text-davinci-003",
|
77 |
+
prompt=user_prompt,
|
78 |
+
max_tokens=128
|
79 |
+
)
|
80 |
+
return response["choices"][0]["text"].strip()
|
81 |
+
except Exception as e:
|
82 |
+
return f"OpenAI Error: {str(e)}"
|
83 |
+
|
84 |
+
def call_deepseek(api_key, user_prompt):
|
85 |
+
"""
|
86 |
+
Hypothetical function to call a DeepSeek API endpoint.
|
87 |
+
Replace with real DeepSeek logic as needed.
|
88 |
+
"""
|
89 |
+
try:
|
90 |
+
headers = {
|
91 |
+
"Content-Type": "application/json",
|
92 |
+
"Authorization": f"Bearer {api_key}"
|
93 |
+
}
|
94 |
+
payload = {
|
95 |
+
"prompt": user_prompt,
|
96 |
+
"max_tokens": 128
|
97 |
+
}
|
98 |
+
# Example POST; adapt to the real DeepSeek endpoint
|
99 |
+
response = requests.post(
|
100 |
+
"https://api.deepseek.ai/v1/chat",
|
101 |
+
json=payload,
|
102 |
+
headers=headers
|
103 |
+
)
|
104 |
+
response.raise_for_status()
|
105 |
+
data = response.json()
|
106 |
+
return data["choices"][0]["text"].strip()
|
107 |
+
except Exception as e:
|
108 |
+
return f"DeepSeek Error: {str(e)}"
|
109 |
+
|
110 |
+
def call_gemini(api_key, user_prompt):
|
111 |
+
"""
|
112 |
+
Hypothetical function for Gemini LLM.
|
113 |
+
Replace with real Gemini logic.
|
114 |
+
"""
|
115 |
+
return "(Gemini usage not yet implemented; placeholder)"
|
116 |
+
|
117 |
+
def call_llm(api_provider, api_key, user_prompt):
|
118 |
+
"""Routes calls to the correct LLM provider."""
|
119 |
+
if not api_key:
|
120 |
+
return "No API key provided. Using GPT-4o-mini default is not possible without HF key." if api_provider.lower() == "gpt-4o-mini" else "No API key provided."
|
121 |
+
|
122 |
+
provider_lower = api_provider.lower()
|
123 |
+
if provider_lower == "gpt-4o-mini":
|
124 |
+
return call_gpt4o_mini(api_key, user_prompt)
|
125 |
+
elif provider_lower == "openai":
|
126 |
+
return call_openai(api_key, user_prompt)
|
127 |
+
elif provider_lower == "deepseek":
|
128 |
+
return call_deepseek(api_key, user_prompt)
|
129 |
+
elif provider_lower == "gemini":
|
130 |
+
return call_gemini(api_key, user_prompt)
|
131 |
+
else:
|
132 |
+
return f"Unknown provider: {api_provider}. Please choose GPT-4o-mini, OpenAI, DeepSeek, or Gemini."
|
133 |
+
|
134 |
+
# --------------------------------------------------------------------------------
|
135 |
+
# ADVANCED DATA ANALYSIS (extended_analysis)
|
136 |
+
# --------------------------------------------------------------------------------
|
137 |
+
def extended_analysis(df):
|
138 |
+
"""
|
139 |
+
Sample advanced analysis:
|
140 |
+
1. Correlation heatmap for numeric columns
|
141 |
+
2. Bar plot of 'Career' (if present)
|
142 |
+
3. Simple logistic regression classification if 'Career' is suitable
|
143 |
+
"""
|
144 |
+
output_paths = []
|
145 |
+
numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
|
146 |
+
|
147 |
+
cat_cols = df.select_dtypes(exclude=["number"]).columns.tolist()
|
148 |
+
|
149 |
+
# 1) Correlation Heatmap
|
150 |
+
if len(numeric_cols) > 1:
|
151 |
+
corr = df[numeric_cols].corr()
|
152 |
+
plt.figure(figsize=(8, 6))
|
153 |
+
sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f")
|
154 |
+
plt.title("Correlation Heatmap")
|
155 |
+
heatmap_path = "heatmap.png"
|
156 |
+
plt.savefig(heatmap_path)
|
157 |
+
plt.close()
|
158 |
+
output_paths.append(heatmap_path)
|
159 |
+
|
160 |
+
# 2) Bar Plot of 'Career' if present
|
161 |
+
if "Career" in df.columns:
|
162 |
+
plt.figure(figsize=(8, 5))
|
163 |
+
df["Career"].value_counts().plot(kind="bar")
|
164 |
+
plt.title("Count of Each Career")
|
165 |
+
plt.xlabel("Career")
|
166 |
+
plt.ylabel("Count")
|
167 |
+
barplot_path = "barplot_career.png"
|
168 |
+
plt.savefig(barplot_path)
|
169 |
+
plt.close()
|
170 |
+
output_paths.append(barplot_path)
|
171 |
+
|
172 |
+
# 3) Simple Logistic Regression if 'Career' exists with multiple categories
|
173 |
+
if "Career" in df.columns and len(numeric_cols) > 0:
|
174 |
+
le = LabelEncoder()
|
175 |
+
df["Career_encoded"] = le.fit_transform(df["Career"])
|
176 |
+
X = df[numeric_cols].fillna(0)
|
177 |
+
y = df["Career_encoded"]
|
178 |
+
if len(np.unique(y)) > 1:
|
179 |
+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
180 |
+
model = LogisticRegression(max_iter=1000)
|
181 |
+
model.fit(X_train, y_train)
|
182 |
+
score = model.score(X_test, y_test)
|
183 |
+
accuracy_info = f"Logistic Regression accuracy on test set: {score:.2f}"
|
184 |
+
else:
|
185 |
+
accuracy_info = "Career column has only one class; no classification performed."
|
186 |
+
else:
|
187 |
+
accuracy_info = "No 'Career' column or insufficient numeric data for classification."
|
188 |
+
|
189 |
+
return output_paths, accuracy_info
|
190 |
+
|
191 |
+
# --------------------------------------------------------------------------------
|
192 |
+
# MAIN ANALYSIS AND VISUALIZATION FUNCTION
|
193 |
+
# --------------------------------------------------------------------------------
|
194 |
+
def analyze_and_visualize(
|
195 |
+
file,
|
196 |
+
message,
|
197 |
+
history,
|
198 |
+
api_provider,
|
199 |
+
api_key
|
200 |
+
):
|
201 |
+
"""
|
202 |
+
Loads CSV, gives a summary, calls LLM for suggestions if an API key is provided,
|
203 |
+
does extended analysis if user requests ("sample analysis", "extended analysis", etc.),
|
204 |
+
and returns results/plots in the chatbot.
|
205 |
+
"""
|
206 |
+
try:
|
207 |
+
# Load CSV
|
208 |
+
df = pd.read_csv(file.name)
|
209 |
+
numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
|
210 |
+
categorical_cols = df.select_dtypes(exclude=["number"]).columns.tolist()
|
211 |
+
|
212 |
+
# Basic info
|
213 |
+
summary = (
|
214 |
+
f"**File**: {file.name}\n"
|
215 |
+
f"**Shape**: {df.shape[0]} rows, {df.shape[1]} columns\n"
|
216 |
+
f"**Numerical Columns**: {', '.join(numeric_cols) if numeric_cols else 'None'}\n"
|
217 |
+
f"**Categorical Columns**: {', '.join(categorical_cols) if categorical_cols else 'None'}\n"
|
218 |
+
)
|
219 |
+
|
220 |
+
# LLM suggestions
|
221 |
+
llm_suggestions = ""
|
222 |
+
if api_key:
|
223 |
+
user_prompt = (
|
224 |
+
f"Data Summary:\n{summary}\n\n"
|
225 |
+
f"User question or request: {message}\n"
|
226 |
+
f"Suggest advanced data analysis or steps if relevant."
|
227 |
+
)
|
228 |
+
llm_response = call_llm(api_provider, api_key, user_prompt)
|
229 |
+
llm_suggestions = f"\n**LLM Suggestions**:\n{llm_response}\n"
|
230 |
+
else:
|
231 |
+
llm_suggestions = "\n(No LLM suggestions because no API key provided.)\n"
|
232 |
+
|
233 |
+
# Always produce example histogram if there's at least one numeric column
|
234 |
+
hist_path = None
|
235 |
+
if numeric_cols:
|
236 |
+
plt.figure(figsize=(6, 4))
|
237 |
+
sns.histplot(df[numeric_cols[0]], kde=True)
|
238 |
+
plt.title(f"Distribution of '{numeric_cols[0]}'")
|
239 |
+
plt.tight_layout()
|
240 |
+
hist_path = "temp_plot.png"
|
241 |
+
plt.savefig(hist_path)
|
242 |
+
plt.close()
|
243 |
+
|
244 |
+
# Check if the user wants extended analysis
|
245 |
+
trigger_phrases = ["sample analysis", "extended analysis", "advanced analysis", "run analysis"]
|
246 |
+
analysis_paths = []
|
247 |
+
accuracy_info = ""
|
248 |
+
if any(phrase in message.lower() for phrase in trigger_phrases):
|
249 |
+
analysis_paths, accuracy_info = extended_analysis(df)
|
250 |
+
|
251 |
+
# Build final response text
|
252 |
+
response_text = summary + llm_suggestions
|
253 |
+
if accuracy_info:
|
254 |
+
response_text += f"\n**ML Model Info**: {accuracy_info}\n"
|
255 |
+
|
256 |
+
# Construct the final chatbot content
|
257 |
+
chat_content = [(message, response_text)]
|
258 |
+
if hist_path:
|
259 |
+
chat_content.append((None, (hist_path,)))
|
260 |
+
for path in analysis_paths:
|
261 |
+
chat_content.append((None, (path,)))
|
262 |
+
|
263 |
+
return history + chat_content
|
264 |
+
|
265 |
+
except Exception as e:
|
266 |
+
return history + [(message, f"Error: {str(e)}")]
|
267 |
+
|
268 |
+
# --------------------------------------------------------------------------------
|
269 |
+
# CREATING THE GRADIO APP
|
270 |
+
# --------------------------------------------------------------------------------
|
271 |
+
def create_demo():
|
272 |
+
with gr.Blocks() as demo:
|
273 |
+
gr.Markdown("# 🤖 GPT-4o-mini (Default) + Multi-Provider AI Data Analysis Assistant")
|
274 |
+
gr.Markdown(
|
275 |
+
"""
|
276 |
+
**Features**:
|
277 |
+
- Default LLM: GPT-4o-mini on Hugging Face (requires HF API key).
|
278 |
+
- Other providers: **OpenAI**, **DeepSeek**, **Gemini** (enter their respective API keys).
|
279 |
+
- Upload CSV for data summary & histograms.
|
280 |
+
- Type "sample analysis" or "extended analysis" to trigger correlation heatmaps, bar plots, and a simple logistic regression.
|
281 |
+
"""
|
282 |
+
)
|
283 |
+
|
284 |
+
with gr.Row():
|
285 |
+
api_provider = gr.Dropdown(
|
286 |
+
choices=["GPT-4o-mini", "OpenAI", "DeepSeek", "Gemini"],
|
287 |
+
value="GPT-4o-mini", # default
|
288 |
+
label="LLM Provider",
|
289 |
+
)
|
290 |
+
api_key = gr.Textbox(
|
291 |
+
label="LLM API Key",
|
292 |
+
placeholder="Enter your Hugging Face/DeepSeek/OpenAI/Gemini API key here..."
|
293 |
+
)
|
294 |
+
|
295 |
+
file_input = gr.File(label="Upload CSV File", file_types=[".csv"])
|
296 |
+
chatbot = gr.Chatbot(label="Analysis Output")
|
297 |
+
msg = gr.Textbox(
|
298 |
+
label="Message",
|
299 |
+
placeholder="Ask the AI or type 'sample analysis' for extended analysis..."
|
300 |
+
)
|
301 |
+
|
302 |
+
send_btn = gr.Button("Send")
|
303 |
+
reset_btn = gr.Button("Reset Chat")
|
304 |
+
|
305 |
+
def reset_chat():
|
306 |
+
return []
|
307 |
+
|
308 |
+
msg.submit(
|
309 |
+
fn=lambda f, m, h, p, k: analyze_and_visualize(f, m, h or [], p, k),
|
310 |
+
inputs=[file_input, msg, chatbot, api_provider, api_key],
|
311 |
+
outputs=[chatbot]
|
312 |
+
).then(lambda: "", None, [msg])
|
313 |
+
|
314 |
+
send_btn.click(
|
315 |
+
fn=lambda f, m, h, p, k: analyze_and_visualize(f, m, h or [], p, k),
|
316 |
+
inputs=[file_input, msg, chatbot, api_provider, api_key],
|
317 |
+
outputs=[chatbot]
|
318 |
+
).then(lambda: "", None, [msg])
|
319 |
+
|
320 |
+
reset_btn.click(fn=reset_chat, inputs=[], outputs=[chatbot])
|
321 |
+
|
322 |
+
demo.queue()
|
323 |
+
return demo
|
324 |
+
|
325 |
+
|
326 |
+
demo = create_demo()
|
327 |
+
|
328 |
+
if __name__ == "__main__":
|
329 |
+
demo.launch(share=True)
|