|
import os |
|
import requests |
|
import gradio as gr |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import numpy as np |
|
|
|
from sklearn.model_selection import train_test_split |
|
from sklearn.linear_model import LogisticRegression |
|
from sklearn.preprocessing import LabelEncoder |
|
|
|
|
|
|
|
|
|
def call_gpt4o_mini(api_key, user_prompt):
    """
    Send a prompt to the hypothetical GPT-4o-mini chat endpoint.

    Example endpoint: https://api.gpt4o-mini.com/v1/chat
    - Adjust the JSON structure and keys to your actual service spec.

    Args:
        api_key: Bearer token. Must be a string that starts with "sk-" once
            surrounding whitespace is removed.
        user_prompt: Text sent as the "prompt" field of the JSON payload.

    Returns:
        The "text" of the first entry in the response's "choices" list on
        success. For an invalid token, a transport/HTTP failure, or an
        unexpected response shape, a human-readable message is returned
        instead of raising, so the caller can show it in the chat.
    """
    # A pasted token often carries stray spaces or a trailing newline; that
    # would fail the prefix check or end up inside the Authorization header.
    token = api_key.strip() if isinstance(api_key, str) else ""
    if not token.startswith("sk-"):
        return "Please provide a valid GPT-4o-mini token (sk-...)."

    url = "https://api.gpt4o-mini.com/v1/chat"
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    payload = {
        "prompt": user_prompt,
        "max_tokens": 128,
        "temperature": 0.7,
    }

    # Transport, HTTP-status and JSON-decoding problems.
    try:
        response = requests.post(url, json=payload, headers=headers, timeout=10)
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        return f"Error calling GPT-4o-mini: {e}"

    # The body decoded fine but does not have the expected structure.
    try:
        return data["choices"][0]["text"]
    except (KeyError, IndexError, TypeError) as e:
        return f"Error calling GPT-4o-mini: unexpected response format ({e!r})"
|
|
|
|
|
|
|
|
|
|
|
def _save_correlation_heatmap(numeric_df, path="heatmap.png"):
    """Save an annotated correlation heatmap of ``numeric_df`` to ``path`` and return ``path``."""
    corr = numeric_df.corr()
    plt.figure(figsize=(8, 6))
    try:
        sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f")
        plt.title("Correlation Heatmap")
        plt.savefig(path)
    finally:
        # Close the current figure even if plotting or saving fails.
        plt.close()
    return path


def _save_career_barplot(career, path="career_distribution.png"):
    """Save a bar plot of the value counts of the ``career`` Series to ``path`` and return ``path``."""
    career_counts = career.value_counts()
    plt.figure(figsize=(8, 5))
    try:
        sns.barplot(x=career_counts.index, y=career_counts.values)
        plt.title("Distribution of Careers")
        plt.xlabel("Career")
        plt.ylabel("Count")
        plt.xticks(rotation=45, ha="right")
        plt.savefig(path)
    finally:
        # Close the current figure even if plotting or saving fails.
        plt.close()
    return path


def _career_classification_info(df, feature_cols):
    """Fit a logistic regression predicting 'Career' from ``feature_cols`` and describe the outcome."""
    # Rows without a label cannot be used for training, so they are dropped
    # before the labels reach LabelEncoder.
    labelled = df[df["Career"].notna()]
    if labelled["Career"].nunique() < 2:
        return "Only one category in 'Career'; no classification performed."

    # Encode into a local array instead of writing a 'Career_encoded' column
    # into the caller's DataFrame: that column would be picked up as a numeric
    # feature (leaking the target) the next time the same frame is analysed.
    y = LabelEncoder().fit_transform(labelled["Career"])
    X = labelled[feature_cols].fillna(0)

    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )
        model = LogisticRegression(max_iter=1000)
        model.fit(X_train, y_train)
        score = model.score(X_test, y_test)
    except ValueError as e:
        # Raised e.g. when the data set is too small to split or the training
        # fold ends up with a single class.
        return f"Logistic Regression could not be fitted: {e}"
    return f"Logistic Regression accuracy: {score:.2f}"


def extended_analysis(df):
    """
    Run the plots and the simple classifier for the uploaded data.

    - Correlation heatmap when there is more than one numeric column
      (saved as "heatmap.png").
    - Bar plot of 'Career' value counts when that column exists
      (saved as "career_distribution.png").
    - Logistic regression predicting 'Career' from the numeric columns
      (excluding 'Career' itself) when 'Career' has multiple categories.

    The input DataFrame is not modified.

    Returns (list_of_image_paths, info_string).
    """
    output_paths = []
    numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
    has_career = "Career" in df.columns

    if len(numeric_cols) > 1:
        output_paths.append(_save_correlation_heatmap(df[numeric_cols]))

    if has_career:
        output_paths.append(_save_career_barplot(df["Career"]))

    # A numeric 'Career' column must not be used to predict itself.
    feature_cols = [col for col in numeric_cols if col != "Career"]
    if has_career and feature_cols:
        accuracy_info = _career_classification_info(df, feature_cols)
    else:
        accuracy_info = "No 'Career' column or insufficient numeric columns for classification."

    return output_paths, accuracy_info
|
|
|
|
|
|
|
|
|
|
|
def handle_chat(user_message, df, chat_history, api_key):
    """
    Process one user turn and return the updated chat history.

    - The user's message is always recorded as a role='user' entry.
    - If df is None, the assistant asks the user to upload a CSV.
    - Otherwise a local data summary is built; if api_key is truthy the
      summary and the query are also sent to call_gpt4o_mini and its reply
      is appended to the assistant's answer.
    - If the message contains one of the analysis trigger phrases,
      extended_analysis(df) is run; its info string is added to the answer
      and each produced image is appended as its own assistant message.

    Args:
        user_message: The text the user typed.
        df: The loaded pandas DataFrame, or None when nothing is uploaded.
        chat_history: List of message dicts; it is extended in place
            (a new list is created when None is passed).
        api_key: Optional token forwarded to call_gpt4o_mini.

    Returns:
        The chat history in 'messages' format for the Gradio Chatbot
        (type='messages').
    """
    if chat_history is None:
        chat_history = []

    # Record the user's turn first so it is part of the transcript on every
    # path, including the "no CSV yet" one.
    chat_history.append({"role": "user", "content": user_message})

    if df is None:
        chat_history.append({"role": "assistant", "content": "Please upload a CSV first."})
        return chat_history

    numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
    cat_cols = df.select_dtypes(exclude=["number"]).columns.tolist()
    summary = (
        f"Rows: {df.shape[0]}, Columns: {df.shape[1]}\n"
        f"Numeric: {', '.join(numeric_cols) if numeric_cols else 'None'}\n"
        f"Categorical: {', '.join(cat_cols) if cat_cols else 'None'}"
    )
    reply_text = f"**Data Summary**:\n{summary}"

    if api_key:
        prompt = f"Data Summary:\n{summary}\nUser Query: {user_message}"
        gpt_reply = call_gpt4o_mini(api_key, prompt)
        if gpt_reply:
            reply_text += f"\n\n**GPT-4o-mini**: {gpt_reply}"

    image_paths = []
    triggers = ["sample analysis", "extended analysis", "advanced analysis", "run analysis", "visualize", "plot"]
    lowered = user_message.lower()
    if any(t in lowered for t in triggers):
        image_paths, info = extended_analysis(df)
        if info:
            reply_text += f"\n\n**Analysis Info**: {info}"

    chat_history.append({"role": "assistant", "content": reply_text})
    # In the 'messages' format a file is passed as the message content
    # ({"path": ...}); a None content with a separate "image" key is not a
    # valid chat message.
    chat_history.extend(
        {"role": "assistant", "content": {"path": path}} for path in image_paths
    )
    return chat_history
|
|
|
|
|
|
|
|
|
|
|
def create_demo():
    """
    Build the Gradio Blocks UI for the data-analysis chat.

    The app holds two pieces of state: the uploaded DataFrame and the
    chat history. Messages are routed through handle_chat and the resulting
    history is mirrored into the Chatbot component.

    Returns:
        The constructed gr.Blocks instance (not launched).
    """
    with gr.Blocks() as demo:
        df_state = gr.State(None)
        chat_state = gr.State([])

        gr.Markdown("## GPT-4o-mini Data Analysis Assistant (Chat)")
        gr.Markdown(
            """
            1. Enter your GPT-4o-mini token (`sk-...`) if you want AI suggestions.
            2. Upload a CSV file.
            3. Ask questions or request "sample analysis", "visualize", etc.
            4. Images are displayed in the chat when relevant.
            """
        )

        api_key_box = gr.Textbox(
            label="GPT-4o-mini Token (sk-...)",
            placeholder="Optional: sk-xxxx",
            type="password",  # the token is a secret; do not echo it on screen
        )
        file_input = gr.File(label="Upload CSV", file_types=[".csv"])

        chatbot = gr.Chatbot(label="Chat Output", type="messages")

        user_message = gr.Textbox(label="Your Message", placeholder="Ask about your data...")
        send_btn = gr.Button("Send")
        clear_btn = gr.Button("Clear Chat")

        def upload_csv(file):
            """
            On file upload, load the CSV into a DataFrame for df_state.
            Returns None when the file input is cleared.
            """
            if file is None:
                return None
            # The upload may arrive as an object exposing `.name` or as a
            # plain path string; accept both.
            path = getattr(file, "name", file)
            return pd.read_csv(path)

        file_input.change(fn=upload_csv, inputs=file_input, outputs=df_state)

        def on_user_message(message, df, chat_history, api_key):
            """
            Called when the user sends a message. Blank messages leave the
            history untouched; otherwise handle_chat produces the new history.
            """
            if not message or not message.strip():
                return chat_history
            return handle_chat(message, df, chat_history, api_key)

        def wire_send(trigger):
            """Attach the send pipeline to an event: update state, refresh chatbot, clear input."""
            trigger(
                fn=on_user_message,
                inputs=[user_message, df_state, chat_state, api_key_box],
                outputs=chat_state,
            ).then(
                fn=lambda messages: messages,
                inputs=chat_state,
                outputs=chatbot,
            ).then(
                fn=lambda: "",
                outputs=user_message,
            )

        # Pressing Enter in the textbox and clicking "Send" behave the same.
        for trigger in (user_message.submit, send_btn.click):
            wire_send(trigger)

        def clear_chat():
            """Reset both the stored history and the visible chatbot."""
            return [], []

        clear_btn.click(
            fn=clear_chat,
            inputs=[],
            outputs=[chat_state, chatbot],
        )

    return demo
|
|
|
# Build the app at import time so the module exposes a ready-made `demo` object.
demo = create_demo()


# Start the server only when this file is executed directly as a script.
if __name__ == "__main__":
    demo.launch(share=True)
|
|