# jzou19950715's picture
# Update app.py
# 4ad3262 verified
# raw
# history blame
# 7.22 kB
import os
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
import gradio as gr
import pandas as pd
import torch
from litellm import completion
# Agent Implementation
@dataclass
class Tool:
    """Record describing one capability that can be registered on an agent."""

    # Short identifier; listed in the agent's system prompt.
    name: str
    # Human-readable summary; listed next to the name in the system prompt.
    description: str
    # The Python callable that implements the capability.
    func: Callable
class MinimalAgent:
    """Minimal single-turn agent for demo purposes.

    The agent sends one system message (listing the registered tools) and
    one user message to the chat-completion backend and returns the text of
    the first choice. Registered tools are only *described* to the model;
    no method of this class invokes ``Tool.func``.
    """

    def __init__(
        self,
        model_id: str = "gpt-4o-mini",
        temperature: float = 0.7,
        max_steps: int = 5
    ):
        """Store the model settings and start with an empty tool list.

        Args:
            model_id: Model name passed to ``completion``.
            temperature: Sampling temperature passed to ``completion``.
            max_steps: Stored on the instance; not read by any method here.
        """
        self.model_id = model_id
        self.temperature = temperature
        self.max_steps = max_steps
        self.tools: List[Tool] = []

    def add_tool(self, name: str, description: str, func: Callable) -> None:
        """Register a tool so it is listed in the system prompt."""
        self.tools.append(Tool(name=name, description=description, func=func))

    def run(self, prompt: str, **kwargs) -> str:
        """Send ``prompt`` to the model and return its reply as text.

        Args:
            prompt: The user message.
            **kwargs: Accepted for call-site compatibility; currently unused.

        Returns:
            The reply text, ``""`` when the reply carries no text content,
            or ``"Error: <message>"`` when the request fails. This method
            never raises for request failures.
        """
        messages = [
            {"role": "system", "content": self._get_system_prompt()},
            {"role": "user", "content": prompt},
        ]
        try:
            response = completion(
                model=self.model_id,
                messages=messages,
                temperature=self.temperature,
            )
            content = response.choices[0].message.content
        except Exception as e:
            # Deliberate best-effort: the caller shows the returned string
            # to the user, so failures are reported as text, not raised.
            return f"Error: {e}"
        # The content field can be None; honour the declared ``str`` return.
        return content if content is not None else ""

    def _get_system_prompt(self) -> str:
        """Build the system prompt, with one bullet line per registered tool."""
        tools_desc = "\n".join(
            f"- {tool.name}: {tool.description}" for tool in self.tools
        )
        return f"""You are a helpful AI agent that can analyze data and write code.
Available tools:
{tools_desc}
Additional capabilities:
- Data analysis with pandas, numpy
- Visualization with matplotlib, seaborn
- Machine learning with sklearn
- Statistical analysis with scipy
Provide clear explanations and code examples."""
# Analysis Functions
def analyze_dataframe(df: pd.DataFrame, analysis_type: str) -> str:
    """Return a textual report about ``df``.

    Args:
        df: The DataFrame to describe.
        analysis_type: ``"summary"`` for ``df.describe()`` output or
            ``"info"`` for the text produced by ``df.info()``.

    Returns:
        The requested report, or ``"Unknown analysis type"`` for any other
        ``analysis_type``.
    """
    # Function-scope import, as plot_data already does for its plotting deps.
    import io

    if analysis_type == "summary":
        return str(df.describe())
    if analysis_type == "info":
        # DataFrame.info writes to a file-like object; a list has no
        # ``write`` method, so capture the text in a StringIO instead.
        buffer = io.StringIO()
        df.info(buf=buffer)
        return buffer.getvalue()
    return "Unknown analysis type"
def plot_data(df: pd.DataFrame, plot_type: str) -> None:
    """Draw a basic plot of ``df`` on the current matplotlib state.

    Args:
        df: The DataFrame to plot.
        plot_type: ``"correlation"`` for an annotated correlation heatmap,
            or ``"distribution"`` for one histogram per column as produced
            by ``DataFrame.hist``. Any other value draws nothing.

    Returns:
        None. The figure is left on matplotlib's current state; this
        function neither shows nor saves it.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    if plot_type == "correlation":
        # Restrict to numeric columns: correlating a frame that still holds
        # string/object columns raises on recent pandas versions.
        numeric_df = df.select_dtypes(include="number")
        plt.figure(figsize=(10, 8))
        sns.heatmap(numeric_df.corr(), annot=True)
        plt.title("Correlation Heatmap")
    elif plot_type == "distribution":
        df.hist(figsize=(15, 10))
    else:
        # Unknown plot type: return before tight_layout so no empty figure
        # is created implicitly.
        return
    plt.tight_layout()
def process_file(file: Any) -> Optional[pd.DataFrame]:
    """Load an uploaded CSV or Excel file into a DataFrame.

    Args:
        file: The upload value handed to the handler. Either an object whose
            ``name`` attribute is the file path, or the path itself as a
            string. NOTE(review): which of the two Gradio passes depends on
            the Gradio version / ``gr.File`` ``type`` setting — confirm.

    Returns:
        The parsed DataFrame, or ``None`` when no file was given, the
        extension is not ``.csv`` / ``.xlsx`` / ``.xls`` (matched
        case-insensitively), or reading fails.
    """
    if not file:
        return None

    path = getattr(file, "name", file)
    # Compare in lowercase so "DATA.CSV" is treated like "data.csv".
    suffix_probe = str(path).lower()
    try:
        if suffix_probe.endswith(".csv"):
            return pd.read_csv(path)
        if suffix_probe.endswith((".xlsx", ".xls")):
            return pd.read_excel(path)
    except Exception as e:
        # Best-effort: report the problem and fall through to None so the
        # caller can show its own error message.
        print(f"Error reading file: {e}")
    return None
def analyze_data(
    file: Any,
    query: str,
    api_key: str,
    temperature: float = 0.7,
) -> str:
    """Handle one analysis request from the UI and return the model's answer.

    Args:
        file: The uploaded file value from the ``gr.File`` input.
        query: The user's natural-language request.
        api_key: API key exported as ``OPENAI_API_KEY`` for the request.
        temperature: Sampling temperature given to the agent.

    Returns:
        The agent's reply, or a string starting with ``"Error"`` when the
        key or file is missing, the file cannot be parsed, or any exception
        occurs. This function does not raise.
    """
    # Treat a whitespace-only key as missing and drop stray spaces/newlines
    # picked up when the key was pasted.
    api_key = (api_key or "").strip()
    if not api_key:
        return "Error: Please provide an API key."
    if not file:
        return "Error: Please upload a file."

    try:
        # NOTE(review): os.environ is process-wide, so the most recently
        # submitted key is visible to every request served by this process.
        os.environ["OPENAI_API_KEY"] = api_key

        agent = MinimalAgent(
            model_id="gpt-4o-mini",
            temperature=temperature
        )
        agent.add_tool(
            "analyze_dataframe",
            "Analyze DataFrame with various metrics",
            analyze_dataframe
        )
        agent.add_tool(
            "plot_data",
            "Create various plots from DataFrame",
            plot_data
        )

        df = process_file(file)
        if df is None:
            return "Error: Could not process file."

        # Column labels are not guaranteed to be strings (e.g. integer
        # headers), so convert before joining.
        column_names = ', '.join(str(col) for col in df.columns)
        dtype_lines = "\n".join(
            f'- {col}: {dtype}' for col, dtype in df.dtypes.items()
        )
        file_name = getattr(file, "name", file)
        file_info = f"""
File: {file_name}
Shape: {df.shape}
Columns: {column_names}
Column Types:
{dtype_lines}
"""
        prompt = f"""
{file_info}
The data is loaded in a pandas DataFrame called 'df'.
User request: {query}
Please analyze the data and provide:
1. A clear explanation of your approach
2. Code for the analysis
3. Visualizations where relevant
4. Key insights and findings
"""
        return agent.run(prompt)
    except Exception as e:
        # Top-level UI boundary: surface the failure as text in the output
        # panel instead of raising into the event handler.
        return f"Error occurred: {e}"
def create_interface():
    """Assemble the Gradio Blocks UI and return it without launching it.

    Layout: an introductory Markdown panel, then a row with the input
    controls (file upload, query, API key, temperature, Analyze button) in
    one column and the Markdown output in the other. Clicking the button
    calls ``analyze_data``; the example rows prefill only the query box.
    """
    intro_text = """
# AI Data Analysis Assistant
Upload your data file and ask questions in natural language.
**Features:**
- Data analysis and visualization
- Statistical analysis
- Machine learning capabilities
**Note**: Requires your own GPT-4 API key.
"""
    sample_queries = (
        "Show key statistics and create visualizations for numeric columns",
        "Find correlations and patterns in the data",
        "Identify outliers and unusual patterns",
        "Create summary visualizations of the main variables",
    )

    with gr.Blocks(title="AI Data Analysis Assistant") as interface:
        gr.Markdown(intro_text)

        with gr.Row():
            with gr.Column():
                upload = gr.File(
                    label="Upload Data File",
                    file_types=[".csv", ".xlsx", ".xls"],
                )
                question = gr.Textbox(
                    label="What would you like to analyze?",
                    placeholder="e.g., Create visualizations showing relationships between variables",
                    lines=3,
                )
                key_box = gr.Textbox(
                    label="API Key (Required)",
                    placeholder="Your API key",
                    type="password",
                )
                temp_slider = gr.Slider(
                    label="Temperature",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                )
                run_button = gr.Button("Analyze")
            with gr.Column():
                result_panel = gr.Markdown(label="Output")

        run_button.click(
            analyze_data,
            inputs=[upload, question, key_box, temp_slider],
            outputs=result_panel,
        )

        # Each example leaves the file slot empty and fills in the query.
        gr.Examples(
            examples=[[None, text] for text in sample_queries],
            inputs=[upload, question],
        )

    return interface
# Build and launch the UI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    interface = create_interface()
    interface.launch()