|
import os |
|
from typing import Optional |
|
|
|
import gradio as gr |
|
import pandas as pd |
|
|
|
from minimal_agent import MinimalAgent |
|
|
|
def analyze_dataframe(df: pd.DataFrame, analysis_type: str) -> str:
    """Return a plain-text report about *df*.

    Args:
        df: DataFrame to analyze.
        analysis_type: ``"summary"`` for the output of ``df.describe()``, or
            ``"info"`` for the output of ``df.info()``.

    Returns:
        The report text, or ``"Unknown analysis type"`` for any other
        *analysis_type*.
    """
    from io import StringIO

    if analysis_type == "summary":
        return str(df.describe())
    if analysis_type == "info":
        # DataFrame.info() writes into a file-like object (anything with a
        # .write() method). A plain list has no .write(), so capture the
        # output in a StringIO instead.
        buffer = StringIO()
        df.info(buf=buffer)
        return buffer.getvalue()
    return "Unknown analysis type"
|
|
|
def plot_data(
    df: pd.DataFrame, plot_type: str
) -> Optional["matplotlib.figure.Figure"]:
    """Draw a plot of *df* with matplotlib/seaborn.

    Args:
        df: DataFrame to plot.
        plot_type: ``"correlation"`` for an annotated seaborn heatmap of the
            correlations between numeric columns, or ``"distribution"`` for
            ``DataFrame.hist`` histograms.

    Returns:
        The matplotlib figure that was drawn on, or ``None`` when
        *plot_type* is not recognized (nothing is drawn in that case).
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    if plot_type == "correlation":
        fig = plt.figure(figsize=(10, 8))
        # numeric_only=True: since pandas 2.0, corr() raises on non-numeric
        # columns instead of silently dropping them.
        sns.heatmap(df.corr(numeric_only=True), annot=True)
        plt.title("Correlation Heatmap")
        return fig

    if plot_type == "distribution":
        df.hist(figsize=(15, 10))
        plt.tight_layout()
        # DataFrame.hist creates its own figure, which is now the current one.
        return plt.gcf()

    return None
|
|
|
def process_file(file: "gr.File") -> Optional[pd.DataFrame]:
    """Load an uploaded CSV or Excel file into a DataFrame.

    Args:
        file: The upload value from ``gr.File``. Accepted forms are a path
            (``str`` / ``os.PathLike``) or an object exposing the path as
            ``.name``. Depending on the Gradio version and the component's
            ``type`` setting, either form may arrive.

    Returns:
        The parsed DataFrame, or ``None`` in any of these cases:
        - no file was given;
        - the extension is not ``.csv``, ``.xlsx`` or ``.xls``
          (case-insensitive);
        - reading failed (the error is printed).
    """
    if not file:
        return None

    try:
        # Check path-likes first: pathlib.Path also has a ``.name`` attribute,
        # but that is only the basename, not the full path.
        if isinstance(file, (str, os.PathLike)):
            path = os.fspath(file)
        else:
            path = file.name
        suffix = os.path.splitext(path)[1].lower()

        if suffix == ".csv":
            return pd.read_csv(path)
        if suffix in (".xlsx", ".xls"):
            return pd.read_excel(path)
    except Exception as e:
        # User uploads are untrusted and arbitrary; report and fall through
        # to None so the caller can show a friendly error.
        print(f"Error reading file: {str(e)}")

    return None
|
|
|
def analyze_data(
    file: "gr.File",
    query: str,
    api_key: str,
    temperature: float = 0.7,
) -> str:
    """Run the LLM agent on an uploaded data file and return its answer.

    Args:
        file: Upload value from the ``gr.File`` component. This is either a
            path or an object exposing the path as ``.name``.
        query: The user's natural-language analysis request.
        api_key: API key, exported as ``OPENAI_API_KEY`` before the agent is
            created.
        temperature: Temperature passed to ``MinimalAgent``.

    Returns:
        The agent's response, or a message starting with ``"Error"`` for
        missing input or any failure. Exceptions are never propagated, so
        the UI always receives displayable text.
    """
    if not api_key:
        return "Error: Please provide an API key."
    if not file:
        return "Error: Please upload a file."
    if not query or not query.strip():
        return "Error: Please describe what you would like to analyze."

    try:
        # Load the data before touching the environment or building the agent,
        # so a bad upload fails fast without side effects.
        df = process_file(file)
        if df is None:
            return "Error: Could not process file."

        # NOTE(review): os.environ is process-global, so concurrent users of a
        # shared deployment overwrite each other's key. Pass the key to
        # MinimalAgent directly if its constructor supports it — TODO confirm.
        os.environ["OPENAI_API_KEY"] = api_key

        agent = MinimalAgent(model_id="gpt-4o-mini", temperature=temperature)
        agent.add_tool(
            "analyze_dataframe",
            "Analyze DataFrame with various metrics",
            analyze_dataframe,
        )
        agent.add_tool(
            "plot_data",
            "Create various plots from DataFrame",
            plot_data,
        )

        # NOTE(review): ``df`` is never handed to the agent here. The prompt's
        # statement that it is loaded as 'df' only holds if MinimalAgent
        # provides it some other way — verify.
        prompt = _build_analysis_prompt(_describe_file(file, df), query)
        return agent.run(prompt)
    except Exception as e:
        # Top-level UI boundary: report the failure as text instead of raising.
        return f"Error occurred: {str(e)}"


def _describe_file(file, df: pd.DataFrame) -> str:
    """Summarize the upload's name, shape, column labels and dtypes."""
    if isinstance(file, (str, os.PathLike)):
        name = os.fspath(file)
    else:
        name = file.name
    # str() every label: column names may be non-strings (e.g. ints when a
    # file has no header row), and str.join rejects those.
    columns = ", ".join(map(str, df.columns))
    column_types = "\n".join(f"- {col}: {dtype}" for col, dtype in df.dtypes.items())
    return (
        f"File: {name}\n"
        f"Shape: {df.shape}\n"
        f"Columns: {columns}\n"
        "\n"
        "Column Types:\n"
        f"{column_types}\n"
    )


def _build_analysis_prompt(file_info: str, query: str) -> str:
    """Wrap the file summary and the user's request in the agent instructions."""
    return (
        f"{file_info}\n"
        "The data is loaded in a pandas DataFrame called 'df'.\n"
        "\n"
        f"User request: {query}\n"
        "\n"
        "Please analyze the data and provide:\n"
        "1. A clear explanation of your approach\n"
        "2. Code for the analysis\n"
        "3. Visualizations where relevant\n"
        "4. Key insights and findings\n"
    )
|
|
|
def create_interface():
    """Build the Gradio Blocks UI for the data-analysis assistant.

    Layout:
        - Left column: file upload, request text box, API key, temperature
          slider and an Analyze button.
        - Right column: a Markdown output area.

    The button calls ``analyze_data`` with those four inputs. The example
    rows fill only the request box; the file slot is ``None``.

    Returns:
        The ``gr.Blocks`` app. It is not launched here.
    """
    with gr.Blocks(title="AI Data Analysis Assistant") as interface:
        # The key is exported as OPENAI_API_KEY and used with gpt-4o-mini, so
        # the UI asks for an OpenAI key rather than a "GPT-4" key.
        gr.Markdown("""
        # AI Data Analysis Assistant

        Upload your data file and ask questions in natural language.

        **Features:**
        - Data analysis and visualization
        - Statistical analysis
        - Machine learning capabilities

        **Note**: Requires your own OpenAI API key.
        """)

        with gr.Row():
            with gr.Column():
                file = gr.File(
                    label="Upload Data File",
                    file_types=[".csv", ".xlsx", ".xls"],
                )
                query = gr.Textbox(
                    label="What would you like to analyze?",
                    placeholder="e.g., Create visualizations showing relationships between variables",
                    lines=3,
                )
                api_key = gr.Textbox(
                    label="OpenAI API Key (Required)",
                    placeholder="Your OpenAI API key",
                    type="password",
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                )
                analyze_btn = gr.Button("Analyze")

            with gr.Column():
                output = gr.Markdown(label="Output")

        analyze_btn.click(
            analyze_data,
            inputs=[file, query, api_key, temperature],
            outputs=output,
        )

        gr.Examples(
            examples=[
                [None, "Show key statistics and create visualizations for numeric columns"],
                [None, "Find correlations and patterns in the data"],
                [None, "Identify outliers and unusual patterns"],
                [None, "Create summary visualizations of the main variables"],
            ],
            inputs=[file, query],
        )

    return interface
|
|
|
if __name__ == "__main__":
    # Script entry point: build the Blocks UI and start serving it.
    create_interface().launch()