|
import base64
import contextlib
import io
import os
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from litellm import completion
|
|
|
|
|
|
|
class CodeEnvironment:
    """Persistent namespace for running model-generated analysis code.

    Preloads ``pd``, ``np``, ``plt`` and ``sns``. :meth:`execute` runs a
    snippet in that namespace and captures printed output plus any
    matplotlib figures.

    NOTE(security): this is *not* a sandbox. ``exec`` runs arbitrary Python
    with the full privileges of the host process (files, network, environment
    variables such as API keys). Only use it where that is acceptable; real
    isolation needs a separate process or container.
    """

    def __init__(self):
        # A single dict serves as both globals and locals for ``exec``. With
        # two separate dicts, top-level assignments in a snippet land in the
        # locals dict. Nested scopes defined in that same snippet (functions,
        # lambdas, comprehensions) cannot see that dict, which causes
        # spurious NameErrors.
        self.globals: Dict[str, Any] = {
            'pd': pd,
            'np': np,
            'plt': plt,
            'sns': sns,
        }
        # Kept for backward compatibility; now the same dict as ``globals``.
        self.locals: Dict[str, Any] = self.globals

    def execute(self, code: str, df: Optional[pd.DataFrame] = None) -> Dict[str, Any]:
        """Run ``code`` in the shared namespace and capture what it produced.

        Args:
            code: Python source to execute.
            df: If given, bound to the name ``df`` before running. It stays
                bound for later calls.

        Returns:
            ``{'output': str, 'figures': List[str], 'error': Optional[str]}``.
            ``output`` is everything printed to stdout, also captured on error.
            ``figures`` holds PNG data URIs of the figures open after a
            successful run. ``error`` is ``"ExcType: message"``, or None.
        """
        if df is not None:
            self.globals['df'] = df

        result: Dict[str, Any] = {'output': '', 'figures': [], 'error': None}
        output_buffer = io.StringIO()
        try:
            # NOTE(review): redirect_stdout swaps sys.stdout process-wide, so
            # prints from concurrently running threads can land in this
            # buffer. pyplot's global figure registry has the same limitation.
            with contextlib.redirect_stdout(output_buffer):
                exec(code, self.globals)
            result['figures'] = self._collect_figures()
        except Exception as e:
            # Include the exception type: str(KeyError('x')) alone is just "'x'".
            result['error'] = f"{type(e).__name__}: {e}"
        finally:
            result['output'] = output_buffer.getvalue()
            output_buffer.close()
            # Don't let figures from a failed run leak into the next call.
            plt.close('all')
        return result

    @staticmethod
    def _collect_figures() -> List[str]:
        """Encode every open pyplot figure as a PNG data URI, closing each one."""
        figures: List[str] = []
        for num in plt.get_fignums():
            fig = plt.figure(num)
            with io.BytesIO() as buf:
                fig.savefig(buf, format='png')
                encoded = base64.b64encode(buf.getvalue()).decode()
            figures.append(f"data:image/png;base64,{encoded}")
            plt.close(fig)
        return figures
|
|
|
@dataclass
class Tool:
    """A named capability that can be registered on an agent.

    Instances are plain records; equality, ``repr`` and the positional
    constructor order (name, description, func) come from ``@dataclass``.
    """

    # Identifier listed for the model.
    name: str
    # Short human-readable summary shown next to the name.
    description: str
    # Implementation backing the tool.
    # NOTE(review): nothing visible in this file invokes it.
    func: Callable[..., Any]
|
|
|
class AnalysisAgent:
    """LLM-backed data analyst that executes the code it writes.

    :meth:`run` sends a prompt to the model through litellm ``completion``
    and extracts fenced Python blocks from the reply. It executes them in a
    :class:`CodeEnvironment` and appends their printed output, figures and
    errors to the reply as Markdown.
    """

    def __init__(
        self,
        model_id: str = "gpt-4o-mini",
        temperature: float = 0.7,
        api_key: Optional[str] = None,
    ):
        """
        Args:
            model_id: litellm model identifier.
            temperature: Sampling temperature forwarded to the model.
            api_key: Optional key forwarded to ``completion`` on every
                request. When None, no key is passed and litellm resolves
                credentials itself (presumably from environment variables).
        """
        self.model_id = model_id
        self.temperature = temperature
        self.api_key = api_key
        self.tools: List[Tool] = []
        self.code_env = CodeEnvironment()

    def add_tool(self, name: str, description: str, func: Callable) -> None:
        """Register a tool so it is listed in the system prompt."""
        self.tools.append(Tool(name=name, description=description, func=func))

    def run(self, prompt: str, df: Optional[pd.DataFrame] = None) -> str:
        """Ask the model for an analysis, run its code, and return Markdown.

        Args:
            prompt: User-facing request, sent as the user message.
            df: DataFrame exposed to executed code as ``df``.

        Returns:
            The model's reply, followed by one section per produced output,
            figure or execution error. Any exception from the request or the
            processing is returned as ``"Error: ..."`` instead of raised.
        """
        messages = [
            {"role": "system", "content": self._get_system_prompt()},
            {"role": "user", "content": prompt},
        ]
        request: Dict[str, Any] = {
            "model": self.model_id,
            "messages": messages,
            "temperature": self.temperature,
        }
        if self.api_key:
            request["api_key"] = self.api_key

        try:
            response = completion(**request)
            # content can be None (e.g. refusals); treat it as empty text.
            analysis = response.choices[0].message.content or ""

            sections: List[str] = []
            for code in self._extract_code(analysis):
                sections.extend(self._render_result(self.code_env.execute(code, df)))

            # Blank lines between sections so Markdown renders each separately.
            return "\n\n".join([analysis, *sections])

        except Exception as e:
            return f"Error: {str(e)}"

    @staticmethod
    def _render_result(result: Dict[str, Any]) -> List[str]:
        """Turn one CodeEnvironment result dict into Markdown sections."""
        if result['error']:
            return [f"Error executing code: {result['error']}"]
        sections: List[str] = []
        output = result['output'].rstrip()
        if output:
            # Fence printed text so tables and whitespace survive Markdown.
            sections.append(f"```text\n{output}\n```")
        sections.extend(f"![Figure]({fig})" for fig in result['figures'])
        return sections

    def _get_system_prompt(self) -> str:
        """Build the system prompt listing registered tools and conventions."""
        # NOTE(review): tools are only *described* to the model; nothing in
        # this class dispatches calls to ``Tool.func``.
        tools_desc = "\n".join(
            f"- {tool.name}: {tool.description}" for tool in self.tools
        ) or "- (none)"
        return (
            "You are a data analysis assistant.\n"
            "\n"
            "Available tools:\n"
            f"{tools_desc}\n"
            "\n"
            "Capabilities:\n"
            "- Data analysis (pandas, numpy)\n"
            "- Visualization (matplotlib, seaborn)\n"
            "- Statistical analysis (scipy)\n"
            "- Machine learning (sklearn)\n"
            "\n"
            "When writing code:\n"
            "- Put code in markdown code blocks tagged ```python; only those blocks are executed\n"
            "- `pd`, `np`, `plt` and `sns` are already imported; use print() for results you want shown\n"
            "- Create clear visualizations\n"
            "- Include explanations\n"
            "- Handle errors gracefully\n"
        )

    @staticmethod
    def _extract_code(text: str) -> List[str]:
        """Extract Python code blocks from markdown.

        Accepts ```` ```python ````, ```` ```python3 ```` and ```` ```py ```` fences
        (case-insensitive, optional trailing spaces, LF or CRLF). Untagged
        fences are ignored, so illustrative snippets in other languages are
        not executed.
        """
        import re

        pattern = r'```[ \t]*(?:python3?|py)[ \t]*\r?\n(.*?)```'
        return re.findall(pattern, text, re.DOTALL | re.IGNORECASE)
|
|
|
def process_file(file: Any) -> Optional[pd.DataFrame]:
    """Load an uploaded CSV or Excel file into a DataFrame.

    Args:
        file: The value received from a ``gr.File`` input. Either a
            filesystem path (``str``/``os.PathLike``), or an object exposing
            the path as ``.name``. Both shapes are accepted because the form
            depends on the Gradio version/configuration. Falsy means "no
            upload".

    Returns:
        The parsed DataFrame. Returns ``None`` when nothing was uploaded,
        when the extension is not .csv/.xlsx/.xls (compared
        case-insensitively), or when parsing fails; a parse failure is
        printed.
    """
    if not file:
        return None

    try:
        path = os.fspath(getattr(file, "name", file))
        suffix = os.path.splitext(path)[1].lower()
        if suffix == '.csv':
            return pd.read_csv(path)
        if suffix in ('.xlsx', '.xls'):
            return pd.read_excel(path)
    except Exception as e:
        print(f"Error reading file: {str(e)}")
    return None
|
|
|
def analyze_data( |
|
file: gr.File, |
|
query: str, |
|
api_key: str, |
|
temperature: float = 0.7, |
|
) -> str: |
|
"""Process user request and generate analysis""" |
|
|
|
if not api_key: |
|
return "Error: Please provide an API key." |
|
|
|
if not file: |
|
return "Error: Please upload a file." |
|
|
|
try: |
|
|
|
os.environ["OPENAI_API_KEY"] = api_key |
|
|
|
|
|
agent = AnalysisAgent( |
|
model_id="gpt-4o-mini", |
|
temperature=temperature |
|
) |
|
|
|
|
|
df = process_file(file) |
|
if df is None: |
|
return "Error: Could not process file." |
|
|
|
|
|
file_info = f""" |
|
File: {file.name} |
|
Shape: {df.shape} |
|
Columns: {', '.join(df.columns)} |
|
|
|
Column Types: |
|
{chr(10).join([f'- {col}: {dtype}' for col, dtype in df.dtypes.items()])} |
|
""" |
|
|
|
|
|
prompt = f""" |
|
{file_info} |
|
|
|
The data is loaded in a pandas DataFrame called 'df'. |
|
|
|
User request: {query} |
|
|
|
Please analyze the data and provide: |
|
1. Key insights and findings |
|
2. Whenever the user request is unclear, proactively interpret them such that it becomes analyzable. |
|
""" |
|
|
|
return agent.run(prompt, df=df) |
|
|
|
except Exception as e: |
|
return f"Error occurred: {str(e)}" |
|
|
|
def create_interface():
    """Build the Gradio Blocks UI for the data-analysis assistant.

    The layout has two columns: inputs (file, request, API key, temperature,
    Analyze button) on the left and a Markdown output on the right. Example
    requests fill only the query box.

    Returns:
        The un-launched ``gr.Blocks`` app.
    """
    with gr.Blocks(title="AI Data Analysis Assistant") as interface:
        # Built from flush-left lines so the Markdown never picks up code
        # indentation, which could render as a code block.
        gr.Markdown(
            "# AI Data Analysis Assistant\n"
            "\n"
            "Upload your data file and get AI-powered analysis with visualizations.\n"
            "\n"
            "**Features:**\n"
            "- Data analysis and visualization\n"
            "- Statistical analysis\n"
            "- Machine learning capabilities\n"
            "\n"
            "**Note**: Requires your own OpenAI API key.\n"
        )

        with gr.Row():
            with gr.Column():
                file = gr.File(
                    label="Upload Data File",
                    file_types=[".csv", ".xlsx", ".xls"]
                )
                query = gr.Textbox(
                    label="What would you like to analyze?",
                    placeholder="e.g., Create visualizations showing relationships between variables",
                    lines=3
                )
                api_key = gr.Textbox(
                    label="API Key (Required)",
                    placeholder="Your API key",
                    type="password"
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.7,
                    step=0.1
                )
                analyze_btn = gr.Button("Analyze")

            with gr.Column():
                output = gr.Markdown(label="Output")

        analyze_btn.click(
            analyze_data,
            inputs=[file, query, api_key, temperature],
            outputs=output
        )

        # Examples leave the file slot empty (None) and only fill the query.
        gr.Examples(
            examples=[
                [None, "Show the distribution of values and key statistics"],
                [None, "Create a correlation analysis with heatmap"],
                [None, "Identify and visualize any outliers in the data"],
                [None, "Generate summary plots for the main variables"],
            ],
            inputs=[file, query]
        )

    return interface
|
|
|
if __name__ == "__main__":
    # Script entry point: build the Blocks app, then start the Gradio server.
    interface = create_interface()
    interface.launch()