# jzou19950715's picture
# Update app.py
# 4722ac6 verified
# raw
# history blame
# 9.43 kB
import os
import requests
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder
##############################################################################
# GPT-4o-mini Placeholder - Adjust for your real endpoint & JSON
##############################################################################
def call_gpt4o_mini(api_key, user_prompt):
    """
    Send ``user_prompt`` to the (placeholder) GPT-4o-mini endpoint and return its reply.

    The endpoint URL, request payload and response layout used here are
    hypothetical -- adjust them to your actual service spec.
    Example endpoint: https://api.gpt4o-mini.com/v1/chat

    Args:
        api_key: Bearer token, expected to start with ``sk-``. Surrounding
            whitespace (e.g. from copy/paste) is ignored; a non-string value
            is treated as a missing token.
        user_prompt: Text sent as the ``prompt`` field of the request.

    Returns:
        The value at ``data["choices"][0]["text"]`` on success. Token,
        network/HTTP, JSON-decoding and response-layout problems are reported
        as a human-readable string instead of being raised.
    """
    token = api_key.strip() if isinstance(api_key, str) else ""
    if not token.startswith("sk-"):
        return "Please provide a valid GPT-4o-mini token (sk-...)."

    url = "https://api.gpt4o-mini.com/v1/chat"  # <--- Replace with real endpoint
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    payload = {
        "prompt": user_prompt,
        "max_tokens": 128,  # limit tokens for cost
        "temperature": 0.7,
    }

    # Transport, HTTP-status and JSON-decoding failures.
    try:
        response = requests.post(url, json=payload, headers=headers, timeout=10)
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        return f"Error calling GPT-4o-mini: {e}"

    # Suppose the text is in data["choices"][0]["text"] (adjust if needed).
    # A reply with a different layout is reported rather than raised.
    try:
        return data["choices"][0]["text"]
    except (KeyError, IndexError, TypeError) as e:
        return f"Error calling GPT-4o-mini: unexpected response format ({e!r})"
##############################################################################
# Local Data Analysis
##############################################################################
def extended_analysis(df):
    """
    Run the optional local analysis on ``df`` and save the resulting plots.

    Steps:
      1. Correlation heatmap over the numeric columns (needs at least two),
         saved to ``heatmap.png``.
      2. Bar plot of the ``Career`` value counts, saved to
         ``career_distribution.png``.
      3. Logistic regression predicting ``Career`` from the numeric columns
         (missing feature values filled with 0), scored on a 20% hold-out
         split. Rows without a ``Career`` value are left out.

    ``df`` itself is not modified: the encoded labels are kept in a local
    variable so repeated calls on the same frame see the same columns.

    Args:
        df: pandas DataFrame to analyse.

    Returns:
        (list_of_image_paths, info_string)
    """
    output_paths = []
    numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
    has_career = "Career" in df.columns

    # 1) Correlation Heatmap
    if len(numeric_cols) > 1:
        corr = df[numeric_cols].corr()
        fig = plt.figure(figsize=(8, 6))
        try:
            sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f")
            plt.title("Correlation Heatmap")
            heatmap_path = "heatmap.png"
            fig.savefig(heatmap_path)
        finally:
            # Release the figure even if drawing or saving fails.
            plt.close(fig)
        output_paths.append(heatmap_path)

    # 2) Bar Plot for 'Career'
    if has_career:
        fig = plt.figure(figsize=(8, 5))
        try:
            career_counts = df["Career"].value_counts()
            sns.barplot(x=career_counts.index, y=career_counts.values)
            plt.title("Distribution of Careers")
            plt.xlabel("Career")
            plt.ylabel("Count")
            plt.xticks(rotation=45, ha="right")
            barplot_path = "career_distribution.png"
            fig.savefig(barplot_path)
        finally:
            plt.close(fig)
        output_paths.append(barplot_path)

    # 3) Simple Logistic Regression
    # The target must never be one of its own predictors (relevant when
    # 'Career' is stored as a numeric code).
    feature_cols = [col for col in numeric_cols if col != "Career"]
    if not has_career or not feature_cols:
        return (
            output_paths,
            "No 'Career' column or insufficient numeric columns for classification.",
        )

    labeled = df["Career"].notna()
    X = df.loc[labeled, feature_cols].fillna(0)
    y = LabelEncoder().fit_transform(df.loc[labeled, "Career"])
    if len(np.unique(y)) <= 1:
        return output_paths, "Only one category in 'Career'; no classification performed."

    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )
        model = LogisticRegression(max_iter=1000)
        model.fit(X_train, y_train)
        score = model.score(X_test, y_test)
    except ValueError as e:
        # Raised e.g. when there are too few rows to split or the training
        # part ends up with a single class; report it instead of crashing.
        return output_paths, f"Logistic Regression could not be run: {e}"

    return output_paths, f"Logistic Regression accuracy: {score:.2f}"
##############################################################################
# Main Chat/Analysis Function
##############################################################################
def handle_chat(user_message, df, chat_history, api_key):
    """
    Process one user message and extend the chat history with the reply.

    - The user message is always recorded first.
    - If ``df`` is None, the reply asks the user to upload a CSV.
    - Otherwise the reply contains a local data summary, plus the
      GPT-4o-mini answer when ``api_key`` is provided.
    - If the message contains one of the trigger phrases, the extended
      analysis is run; its info text is added to the reply and each
      generated image is appended as its own assistant message.

    Args:
        user_message: Text entered by the user.
        df: Loaded pandas DataFrame, or None when nothing was uploaded.
        chat_history: List of ``{"role": ..., "content": ...}`` dicts; it is
            extended in place. None is treated as an empty history.
        api_key: GPT-4o-mini token; falsy to skip the remote call.

    Returns:
        The updated chat history in the 'messages' format used by the Gradio
        Chatbot (type='messages').
    """
    if chat_history is None:
        chat_history = []

    # Always show user message in chat, including when no CSV is loaded yet.
    chat_history.append({"role": "user", "content": user_message})

    if df is None:
        chat_history.append({"role": "assistant", "content": "Please upload a CSV first."})
        return chat_history

    # Summarize data
    numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
    cat_cols = df.select_dtypes(exclude=["number"]).columns.tolist()
    summary = (
        f"Rows: {df.shape[0]}, Columns: {df.shape[1]}\n"
        f"Numeric: {', '.join(numeric_cols) if numeric_cols else 'None'}\n"
        f"Categorical: {', '.join(cat_cols) if cat_cols else 'None'}"
    )

    # Build the reply text (local summary + optional LLM suggestions)
    reply_text = f"**Data Summary**:\n{summary}"
    if api_key:
        prompt = f"Data Summary:\n{summary}\nUser Query: {user_message}"
        gpt_reply = call_gpt4o_mini(api_key, prompt)
        if gpt_reply:
            reply_text += f"\n\n**GPT-4o-mini**: {gpt_reply}"

    # Check if user wants extended analysis
    triggers = (
        "sample analysis",
        "extended analysis",
        "advanced analysis",
        "run analysis",
        "visualize",
        "plot",
    )
    lowered = user_message.lower()
    image_paths = []
    if any(trigger in lowered for trigger in triggers):
        image_paths, info = extended_analysis(df)
        if info:
            reply_text += f"\n\n**Analysis Info**: {info}"

    chat_history.append({"role": "assistant", "content": reply_text})
    # Images go in as separate chat items. In the 'messages' format a file is
    # passed as a dict with a "path" key in "content"; a None content with a
    # custom "image" key is not a valid message.
    for path in image_paths:
        chat_history.append({"role": "assistant", "content": {"path": path}})
    return chat_history
##############################################################################
# Gradio Interface
##############################################################################
def create_demo():
    """
    Build the Gradio Blocks app.

    Layout, top to bottom: heading and usage notes, token box, CSV upload,
    chat display, message box, "Send" button, "Clear Chat" button. Two
    ``gr.State`` values hold the loaded DataFrame and the chat messages.

    Returns:
        The assembled ``gr.Blocks`` object (not launched).
    """
    with gr.Blocks() as demo:
        # State: holds the DataFrame and the chat messages
        df_state = gr.State(None)
        # Messages are stored as a list of dicts: [{"role": "...", "content": "..."}]
        chat_state = gr.State([])

        gr.Markdown("## GPT-4o-mini Data Analysis Assistant (Chat)")
        gr.Markdown(
            "\n"
            "1. Enter your GPT-4o-mini token (`sk-...`) if you want AI suggestions.\n"
            "2. Upload a CSV file.\n"
            '3. Ask questions or request "sample analysis", "visualize", etc.\n'
            "4. Images are displayed in the chat when relevant.\n"
        )

        api_key_box = gr.Textbox(label="GPT-4o-mini Token (sk-...)", placeholder="Optional: sk-xxxx")
        file_input = gr.File(label="Upload CSV", file_types=[".csv"])
        # Chatbot in "messages" format to fix the deprecation warning
        chatbot = gr.Chatbot(label="Chat Output", type="messages")
        user_message = gr.Textbox(label="Your Message", placeholder="Ask about your data...")
        send_btn = gr.Button("Send")
        clear_btn = gr.Button("Clear Chat")

        def upload_csv(file):
            """
            On file upload, load the CSV into a DataFrame for df_state.

            Returns None when the upload is cleared. The chat is left as is.
            A file that cannot be read is reported to the user via gr.Error.
            """
            if file is None:
                return None
            # NOTE(review): depending on the Gradio version/config the value is
            # a path string or an object exposing the path as `.name`; both
            # are accepted here.
            path = getattr(file, "name", file)
            try:
                return pd.read_csv(path)
            except (OSError, ValueError) as exc:
                # pandas parser/empty-data/decoding errors derive from ValueError.
                raise gr.Error(f"Could not read CSV: {exc}") from exc

        def on_user_message(message, df, chat_history, api_key):
            """
            Called when user sends a message. Handle chat + analysis. Return new chat messages.
            """
            if not message or not message.strip():
                return chat_history  # ignore empty
            return handle_chat(message, df, chat_history, api_key)

        def clear_chat():
            """Reset both the stored messages and the visible chat."""
            return [], []

        def wire_send(trigger):
            """Attach the send pipeline to one event (textbox submit or button click)."""
            trigger(
                fn=on_user_message,
                inputs=[user_message, df_state, chat_state, api_key_box],
                outputs=chat_state,
            ).then(
                # After updating chat_state, reflect it in the chatbot
                fn=lambda messages: messages,
                inputs=chat_state,
                outputs=chatbot,
            ).then(
                # Finally empty the message box
                fn=lambda: "",
                outputs=user_message,
            )

        file_input.change(fn=upload_csv, inputs=file_input, outputs=df_state)
        # Pressing Enter in the textbox and clicking "Send" behave identically.
        wire_send(user_message.submit)
        wire_send(send_btn.click)
        clear_btn.click(
            fn=clear_chat,
            inputs=[],
            outputs=[chat_state, chatbot],
        )
    return demo
# Build the interface at import time so the module exposes `demo`.
demo = create_demo()

if __name__ == "__main__":
    # Run directly: start the server and request a shareable link.
    demo.launch(share=True)