|
import os |
|
from dataclasses import dataclass |
|
from typing import Any, Callable, Dict, List, Optional |
|
|
|
import gradio as gr |
|
import pandas as pd |
|
import torch |
|
from litellm import completion |
|
|
|
|
|
@dataclass
class Tool:
    """Record describing one capability registered with an agent.

    Instances are plain data holders: the dataclass decorator generates the
    constructor, ``repr`` and equality from the three fields below.
    """

    # Short identifier shown to the model in the tool listing.
    name: str
    # Human-readable explanation of what the tool does.
    description: str
    # The Python callable that implements the tool.
    func: Callable
|
|
|
class MinimalAgent:
    """Minimal single-turn LLM agent for demo purposes.

    The agent keeps a list of registered tools, lists them in the system
    prompt, and sends one chat-completion request per ``run`` call. Tools are
    only described to the model; no method of this class calls ``Tool.func``.
    """

    def __init__(
        self,
        model_id: str = "gpt-4o-mini",
        temperature: float = 0.7,
        max_steps: int = 5,
    ):
        """Store the model configuration and start with an empty tool list.

        Args:
            model_id: Model name sent to ``completion`` as ``model``.
            temperature: Default sampling temperature for each request.
            max_steps: Stored on the instance; no method of this class reads it.
        """
        self.model_id = model_id
        self.temperature = temperature
        self.max_steps = max_steps
        self.tools: List[Tool] = []

    def add_tool(self, name: str, description: str, func: Callable) -> None:
        """Register a tool so it appears in the system prompt.

        Args:
            name: Identifier shown to the model.
            description: Short explanation shown next to the name.
            func: Callable implementing the tool.
        """
        self.tools.append(Tool(name=name, description=description, func=func))

    def run(self, prompt: str, **kwargs) -> str:
        """Send ``prompt`` to the model and return its reply.

        Args:
            prompt: User message for this request.
            **kwargs: Extra keyword arguments forwarded to ``completion``.
                They may override ``model`` and ``temperature`` for this call;
                a ``messages`` entry is ignored because the conversation is
                built here.

        Returns:
            The reply text, ``""`` when the response carries no text content,
            or ``"Error: <message>"`` when the request raises.
        """
        messages = [
            {"role": "system", "content": self._get_system_prompt()},
            {"role": "user", "content": prompt},
        ]

        try:
            response = completion(**self._completion_params(messages, kwargs))
            content = response.choices[0].message.content
        except Exception as e:
            # Best-effort boundary: the caller displays the returned string,
            # so failures are reported as text instead of propagating.
            return f"Error: {str(e)}"

        # Content can be None; keep the declared ``str`` return type.
        return content if content is not None else ""

    def _completion_params(
        self, messages: List[Dict[str, str]], overrides: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Merge instance defaults with per-call overrides for ``completion``."""
        params: Dict[str, Any] = {
            "model": self.model_id,
            "temperature": self.temperature,
        }
        params.update(overrides)
        # Set last so an override can never replace the conversation itself.
        params["messages"] = messages
        return params

    def _get_system_prompt(self) -> str:
        """Build the system prompt, listing every registered tool."""
        tools_desc = "\n".join(
            f"- {tool.name}: {tool.description}" for tool in self.tools
        )
        return "\n".join(
            [
                "You are a helpful AI agent that can analyze data and write code.",
                "",
                "Available tools:",
                tools_desc,
                "",
                "Additional capabilities:",
                "- Data analysis with pandas, numpy",
                "- Visualization with matplotlib, seaborn",
                "- Machine learning with sklearn",
                "- Statistical analysis with scipy",
                "",
                "Provide clear explanations and code examples.",
            ]
        )
|
|
|
|
|
def analyze_dataframe(df: pd.DataFrame, analysis_type: str) -> str:
    """Return a textual analysis of ``df``.

    Args:
        df: DataFrame to analyze.
        analysis_type: ``"summary"`` for ``df.describe()`` output, or
            ``"info"`` for the report produced by ``df.info()``.

    Returns:
        The requested report as a string, or ``"Unknown analysis type"`` for
        any other ``analysis_type``.
    """
    import io

    if analysis_type == "summary":
        return str(df.describe())

    if analysis_type == "info":
        # DataFrame.info writes to a file-like object; a list has no write().
        buffer = io.StringIO()
        df.info(buf=buffer)
        return buffer.getvalue()

    return "Unknown analysis type"
|
|
|
def plot_data(df: pd.DataFrame, plot_type: str) -> None:
    """Draw a basic plot of ``df`` with matplotlib.

    Args:
        df: DataFrame to plot.
        plot_type: ``"correlation"`` for an annotated correlation heatmap of
            the numeric/boolean columns, or ``"distribution"`` for per-column
            histograms. Any other value draws nothing.

    Raises:
        ValueError: If ``plot_type`` is ``"correlation"`` and ``df`` has no
            numeric or boolean columns.
    """
    # Imported at call time rather than at module import time.
    import matplotlib.pyplot as plt
    import seaborn as sns

    if plot_type == "correlation":
        # Text/object columns cannot be correlated, so drop them up front
        # instead of letting corr() fail on them.
        numeric = df.select_dtypes(include=["number", "bool"])
        if numeric.shape[1] == 0:
            raise ValueError("No numeric columns available for a correlation heatmap")

        plt.figure(figsize=(10, 8))
        sns.heatmap(numeric.corr(), annot=True)
        plt.title("Correlation Heatmap")
    elif plot_type == "distribution":
        df.hist(figsize=(15, 10))
        plt.tight_layout()
|
|
|
def process_file(file: Any) -> Optional[pd.DataFrame]:
    """Load an uploaded CSV or Excel file into a DataFrame.

    Args:
        file: Either a filesystem path string or an object whose ``name``
            attribute is the path.
            NOTE(review): which of the two the upload component passes depends
            on its configuration; confirm against the Gradio version in use.

    Returns:
        The parsed DataFrame, or ``None`` when no file was given, the
        extension is not ``.csv``/``.xlsx``/``.xls`` (case-insensitive), or
        reading fails.
    """
    if not file:
        return None

    try:
        path = file if isinstance(file, str) else file.name
        # Compare on a lowered copy so "DATA.CSV" is handled like "data.csv".
        lowered = path.lower()
        if lowered.endswith(".csv"):
            return pd.read_csv(path)
        if lowered.endswith((".xlsx", ".xls")):
            return pd.read_excel(path)
    except Exception as e:
        # Best effort: report the problem and let the caller handle None.
        print(f"Error reading file: {str(e)}")

    return None
|
|
|
def _build_analysis_prompt(file_name: str, df: pd.DataFrame, query: str) -> str:
    """Compose the user prompt: file summary, the request, and the answer format."""
    # Column labels are not guaranteed to be strings, so stringify before joining.
    column_names = ", ".join(str(col) for col in df.columns)
    column_types = "\n".join(
        f"- {col}: {dtype}" for col, dtype in df.dtypes.items()
    )
    return "\n".join(
        [
            "",
            f"File: {file_name}",
            f"Shape: {df.shape}",
            f"Columns: {column_names}",
            "",
            "Column Types:",
            column_types,
            "",
            "The data is loaded in a pandas DataFrame called 'df'.",
            "",
            f"User request: {query}",
            "",
            "Please analyze the data and provide:",
            "1. A clear explanation of your approach",
            "2. Code for the analysis",
            "3. Visualizations where relevant",
            "4. Key insights and findings",
            "",
        ]
    )


def analyze_data(
    file: Any,
    query: str,
    api_key: str,
    temperature: float = 0.7,
) -> str:
    """Process a user request and return the model's analysis as text.

    Args:
        file: Uploaded file value; passed to ``process_file`` unchanged.
        query: The user's question in natural language.
        api_key: API key; surrounding whitespace is stripped before use.
        temperature: Sampling temperature given to the agent.

    Returns:
        The agent's reply, or a string starting with ``"Error"`` when the key
        or file is missing, the file cannot be read, or anything raises.
    """
    # A pasted key that is only whitespace is as unusable as an empty one.
    api_key = (api_key or "").strip()
    if not api_key:
        return "Error: Please provide an API key."

    if not file:
        return "Error: Please upload a file."

    try:
        # Validate the upload first so a bad file has no side effects.
        df = process_file(file)
        if df is None:
            return "Error: Could not process file."

        # NOTE(review): this sets the key for the whole process, so concurrent
        # users would overwrite each other's key; confirm this is acceptable.
        os.environ["OPENAI_API_KEY"] = api_key

        agent = MinimalAgent(model_id="gpt-4o-mini", temperature=temperature)
        agent.add_tool(
            "analyze_dataframe",
            "Analyze DataFrame with various metrics",
            analyze_dataframe,
        )
        agent.add_tool(
            "plot_data",
            "Create various plots from DataFrame",
            plot_data,
        )

        # The upload may be a path string or an object exposing ``name``.
        file_name = getattr(file, "name", file)
        return agent.run(_build_analysis_prompt(file_name, df, query))

    except Exception as e:
        # UI boundary: show the failure to the user instead of raising.
        return f"Error occurred: {str(e)}"
|
|
|
def create_interface():
    """Assemble the Gradio Blocks UI and return it without launching.

    The layout is one row with two columns: inputs (file upload, question,
    API key, temperature, Analyze button) on the left and the Markdown result
    on the right. Clicking the button calls ``analyze_data``.
    """
    example_prompts = (
        "Show key statistics and create visualizations for numeric columns",
        "Find correlations and patterns in the data",
        "Identify outliers and unusual patterns",
        "Create summary visualizations of the main variables",
    )

    with gr.Blocks(title="AI Data Analysis Assistant") as demo:
        gr.Markdown("""
        # AI Data Analysis Assistant

        Upload your data file and ask questions in natural language.

        **Features:**
        - Data analysis and visualization
        - Statistical analysis
        - Machine learning capabilities

        **Note**: Requires your own GPT-4 API key.
        """)

        with gr.Row():
            # Left column: everything the user provides.
            with gr.Column():
                upload = gr.File(
                    file_types=[".csv", ".xlsx", ".xls"],
                    label="Upload Data File",
                )
                question = gr.Textbox(
                    lines=3,
                    placeholder="e.g., Create visualizations showing relationships between variables",
                    label="What would you like to analyze?",
                )
                key_box = gr.Textbox(
                    type="password",
                    placeholder="Your API key",
                    label="API Key (Required)",
                )
                temp_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    value=0.7,
                    label="Temperature",
                )
                run_button = gr.Button("Analyze")

            # Right column: the rendered answer.
            with gr.Column():
                result_view = gr.Markdown(label="Output")

        run_button.click(
            fn=analyze_data,
            inputs=[upload, question, key_box, temp_slider],
            outputs=result_view,
        )

        # Examples fill in the question only; the file slot stays empty.
        gr.Examples(
            examples=[[None, text] for text in example_prompts],
            inputs=[upload, question],
        )

    return demo
|
|
|
if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server.
    create_interface().launch()