jzou19950715's picture
Update app.py
4dec0f2 verified
raw
history blame
10.3 kB
import os
import logging
import pandas as pd
import google.generativeai as genai
import gradio as gr
from typing import Dict, List, Any, Tuple
import json
import matplotlib.pyplot as plt
import seaborn as sns
import io
import base64
# Module-level logger shared by every class and function in this file.
logger = logging.getLogger(__name__)
# Root configuration: level INFO (basicConfig is a no-op if the root logger
# already has handlers).
logging.basicConfig(level=logging.INFO)
class DataTools:
    """Tools for data analysis that can be called by the AI.

    Each method wraps one operation on the DataFrame given to the constructor.
    Expected failures (missing columns, plotting or correlation errors) are
    reported in the return value, as an ``{"error": ...}`` dict or a string
    starting with ``"Error"``. A bad model-chosen argument therefore does not
    abort the surrounding analysis.
    """

    # Plot kinds create_visualization knows how to draw.
    SUPPORTED_PLOTS = ("histogram", "scatter", "boxplot", "bar")

    def __init__(self, df: pd.DataFrame):
        self.df = df

    def describe_column(self, column: str) -> dict:
        """Return ``describe()`` statistics, null count and dtype for ``column``.

        Returns ``{"error": "Column <name> not found"}`` when the column is absent.
        """
        if column not in self.df.columns:
            return {"error": f"Column {column} not found"}
        series = self.df[column]
        return {
            "statistics": series.describe().to_dict(),
            # int() converts the numpy scalar so the dict stays JSON-friendly.
            "null_count": int(series.isnull().sum()),
            "dtype": str(series.dtype),
        }

    def create_visualization(self, plot_type: str, x: str, y: str = None, title: str = None) -> str:
        """Draw a seaborn plot of the DataFrame and return it as base64 PNG text.

        Args:
            plot_type: One of ``SUPPORTED_PLOTS``.
            x: Column for the x axis.
            y: Column for the y axis (passed to scatter, boxplot and bar).
            title: Optional plot title.

        Returns:
            The base64-encoded PNG on success. Otherwise, a string starting with
            ``"Error creating visualization"``.
        """
        if plot_type not in self.SUPPORTED_PLOTS:
            # Reject unknown kinds up front; otherwise an empty figure would be encoded.
            return (
                "Error creating visualization: unsupported plot type "
                f"{plot_type!r} (expected one of: {', '.join(self.SUPPORTED_PLOTS)})"
            )
        fig = None
        try:
            fig = plt.figure(figsize=(10, 6))
            if plot_type == "histogram":
                sns.histplot(data=self.df, x=x)
            elif plot_type == "scatter":
                sns.scatterplot(data=self.df, x=x, y=y)
            elif plot_type == "boxplot":
                sns.boxplot(data=self.df, x=x, y=y)
            else:  # "bar"
                sns.barplot(data=self.df, x=x, y=y)
            if title:
                plt.title(title)
            buf = io.BytesIO()
            fig.savefig(buf, format="png")
            return base64.b64encode(buf.getvalue()).decode("utf-8")
        except Exception as e:
            return f"Error creating visualization: {str(e)}"
        finally:
            # Close on every path, including seaborn errors. pyplot keeps open
            # figures referenced until they are closed.
            if fig is not None:
                plt.close(fig)

    def get_correlation(self, columns: List[str]) -> dict:
        """Return ``{"correlation_matrix": ...}`` for the given columns.

        A single column name is accepted as a one-element list. Missing columns
        are reported by name. Any other pandas failure comes back as
        ``{"error": "Error calculating correlation: ..."}``.
        """
        try:
            if isinstance(columns, str):
                columns = [columns]
            missing = [c for c in columns if c not in self.df.columns]
            if missing:
                return {"error": f"Columns not found: {', '.join(map(str, missing))}"}
            corr = self.df[columns].corr().to_dict()
            return {"correlation_matrix": corr}
        except Exception as e:
            return {"error": f"Error calculating correlation: {str(e)}"}
class DataAnalyzer:
    """Send CSV context plus a user query to Gemini and run the tools it requests.

    The instance keeps the API key, system prompt, loaded DataFrame and its
    ``DataTools`` wrapper between calls, so later queries reuse them.
    """

    # Only these DataTools methods may be invoked on the model's behalf. Any
    # other name in "tools_used" is reported as an error rather than looked up
    # with getattr, which could otherwise reach ``df`` or dunder attributes.
    TOOL_NAMES = ("describe_column", "create_visualization", "get_correlation")

    # Plain (non-f) string so its literal JSON braces need no escaping when it
    # is interpolated into the prompt.
    RESPONSE_SCHEMA = """{
"answer": "Direct answer to the query",
"tools_used": [
{
"tool": "tool_name",
"parameters": {"param1": "value1"},
"purpose": "Why this tool was used"
}
],
"insights": ["List of key insights"],
"visualizations": ["List of suggested visualizations"],
"recommendations": ["List of recommendations"],
"limitations": ["Any limitations in the analysis"]
}"""

    def __init__(self):
        self.model = None
        self.api_key = None
        self.system_prompt = None
        self.df = None
        self.tools = None

    def configure_api(self, api_key: str) -> bool:
        """Configure the Gemini API with ``api_key``; return True on success.

        An empty key is rejected immediately instead of being passed on.
        """
        if not api_key:
            logger.error("API configuration failed: empty API key")
            return False
        try:
            genai.configure(api_key=api_key)
            self.model = genai.GenerativeModel('gemini-1.5-pro')
            self.api_key = api_key
            return True
        except Exception as e:
            logger.error(f"API configuration failed: {str(e)}")
            return False

    def load_data(self, file) -> Tuple[bool, str]:
        """Load a CSV into ``self.df`` and rebuild ``self.tools``.

        ``file`` may be an upload object exposing ``.name``, or a path-like
        value passed directly to ``pd.read_csv``.

        Returns:
            ``(success, message)``. On failure the previous data is kept.
        """
        path = getattr(file, "name", file)
        try:
            df = pd.read_csv(path)
        except Exception as e:
            logger.error(f"Data loading failed: {str(e)}")
            return False, f"Error loading data: {str(e)}"
        self.df = df
        self.tools = DataTools(df)
        return True, f"Loaded CSV with {len(df)} rows and {len(df.columns)} columns"

    def get_data_info(self) -> Dict[str, Any]:
        """Summarise the loaded DataFrame for the prompt.

        Returns a dict containing:
            ``columns``: the column names.
            ``rows``: the row count.
            ``sample``: the first five rows as records.
            ``dtypes``: per-column dtypes as strings.

        Returns ``{"error": "No data loaded"}`` when nothing is loaded.
        """
        if self.df is None:
            return {"error": "No data loaded"}
        return {
            "columns": list(self.df.columns),
            "rows": len(self.df),
            "sample": self.df.head(5).to_dict('records'),
            "dtypes": self.df.dtypes.astype(str).to_dict(),
        }

    def analyze(self, query: str) -> str:
        """Ask the model to analyse ``query`` and return a Markdown report.

        Every path returns a string, including the guard and error paths. The
        result is displayed directly in the UI's Markdown output.
        """
        if self.model is None:
            return "Please configure API key first"
        if self.df is None:
            return "Please upload a CSV file first"
        prompt = self._build_prompt(query)
        try:
            response_text = self.model.generate_content(prompt).text
            structured = self._parse_json_response(response_text)
            if structured is None:
                return f"Error: Could not parse structured response\n\nRaw response:\n{response_text}"
            tool_outputs = self._run_tools(structured.get("tools_used"))
            return self._format_output(structured, tool_outputs)
        except Exception as e:
            # UI boundary: report the failure as text instead of crashing the handler.
            logger.error(f"Analysis failed: {str(e)}")
            return f"Error during analysis: {str(e)}"

    def _build_prompt(self, query: str) -> str:
        """Combine system prompt, data context, tool list and the JSON schema."""
        data_info = self.get_data_info()
        # default=str keeps the prompt buildable if a sample value is not JSON-native.
        sample = json.dumps(data_info['sample'], indent=2, default=str)
        return f"""{self.system_prompt or ''}
Data Information:
- Columns: {data_info['columns']}
- Number of rows: {data_info['rows']}
- Sample data: {sample}
Available Tools:
1. describe_column(column: str) - Get statistical description of a column
2. create_visualization(plot_type: str, x: str, y: str = None, title: str = None)
- Create visualizations (types: histogram, scatter, boxplot, bar)
3. get_correlation(columns: List[str]) - Get correlation between columns
User Query: {query}
Please provide a structured analysis in the following JSON format:
{self.RESPONSE_SCHEMA}
Important:
- Be specific about which tools to use
- Provide clear reasoning for each tool choice
- Structure the output exactly as shown above
- Respond with the JSON object only, without any surrounding text
"""

    @staticmethod
    def _parse_json_response(text: str):
        """Return the JSON object contained in ``text``, or None if there is none.

        The raw text is tried first. If that fails, the span from the first
        ``{`` to the last ``}`` is tried; this covers JSON wrapped in a
        Markdown code fence or surrounded by prose. Non-object JSON yields None.
        """
        text = text or ""
        candidates = [text.strip()]
        start, end = text.find("{"), text.rfind("}")
        if 0 <= start < end:
            candidates.append(text[start:end + 1])
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed
        return None

    def _run_tools(self, tool_calls) -> List[Dict[str, Any]]:
        """Execute the model-requested tool calls and collect their outputs.

        Each output is ``{"tool", "parameters", "result"}``. Unknown tools,
        non-dict parameters and exceptions raised by a tool become an
        ``{"error": ...}`` result for that call only.
        """
        outputs = []
        if not isinstance(tool_calls, list):
            return outputs
        for call in tool_calls:
            if not isinstance(call, dict):
                continue
            name = call.get("tool")
            params = call.get("parameters") or {}
            method = getattr(self.tools, name, None) if name in self.TOOL_NAMES else None
            if method is None or not isinstance(params, dict):
                result = {"error": f"Unknown tool or invalid parameters: {name}"}
            else:
                try:
                    result = method(**params)
                except Exception as e:
                    # One bad call (e.g. an unexpected keyword) must not sink the report.
                    result = {"error": f"Tool {name} failed: {str(e)}"}
            outputs.append({"tool": name, "parameters": params, "result": result})
        return outputs

    @staticmethod
    def _bullets(items) -> str:
        """Render ``items`` as Markdown bullet lines (a lone string is one bullet)."""
        if isinstance(items, str):
            items = [items]
        if not isinstance(items, list):
            return ""
        return "".join(f"- {item}\n" for item in items)

    @staticmethod
    def _render_tool_output(out: Dict[str, Any]) -> str:
        """Render one tool output as a bold heading followed by its result."""
        result = out["result"]
        if (out["tool"] == "create_visualization" and isinstance(result, str)
                and not result.startswith("Error")):
            # create_visualization returns base64 PNG text; embed it as an image.
            return f"\n**{out['tool']}**:\n\n![visualization](data:image/png;base64,{result})\n"
        body = json.dumps(result, indent=2, default=str)
        return f"\n**{out['tool']}**:\n```json\n{body}\n```\n"

    @classmethod
    def _format_output(cls, structured: Dict[str, Any], tool_outputs: List[Dict[str, Any]]) -> str:
        """Render the parsed model response and tool results as Markdown.

        Missing sections render as empty instead of raising KeyError.
        """
        tools_md = "".join(cls._render_tool_output(out) for out in tool_outputs)
        return f"""## Analysis Results
{structured.get('answer', '')}
### Key Insights
{cls._bullets(structured.get('insights'))}
### Visualizations
{cls._bullets(structured.get('visualizations'))}
### Recommendations
{cls._bullets(structured.get('recommendations'))}
### Limitations
{cls._bullets(structured.get('limitations'))}
---
Tool Outputs:
{tools_md}
"""
def create_interface():
    """Build the Gradio Blocks UI wired to one DataAnalyzer instance.

    A single ``DataAnalyzer`` is created here and reused by every click of the
    Analyze button.
    NOTE(review): that instance is presumably shared by all concurrent users
    of the app (key, prompt and loaded data included). Consider per-session
    state; verify.

    Returns:
        The constructed ``gr.Blocks`` app, not yet launched.
    """
    analyzer = DataAnalyzer()
    default_prompt = (
        "You are an advanced data analysis expert. Analyze the provided data and answer the query.\n"
        "Focus on:\n"
        "1. Clear, structured analysis\n"
        "2. Statistical insights\n"
        "3. Appropriate visualizations\n"
        "4. Actionable recommendations"
    )

    def process_inputs(api_key: str, system_prompt: str, file, query: str):
        """Apply the form fields to the analyzer, then run the query."""
        key_changed = api_key != analyzer.api_key
        # Only reconfigure when the key differs from the one already in use.
        if key_changed and not analyzer.configure_api(api_key):
            return "Failed to configure API. Please check your API key."
        analyzer.system_prompt = system_prompt
        if file is not None:
            loaded, status = analyzer.load_data(file)
            if not loaded:
                return status
        return analyzer.analyze(query)

    with gr.Blocks(title="Advanced Data Analysis Assistant") as app:
        gr.Markdown("# Advanced Data Analysis Assistant")
        gr.Markdown("Upload your CSV file and get AI-powered analysis with visualizations")
        with gr.Row():
            key_box = gr.Textbox(
                label="Gemini API Key",
                placeholder="Enter your Gemini API key",
                type="password",
            )
        with gr.Row():
            prompt_box = gr.Textbox(
                label="System Prompt",
                placeholder="Enter system prompt for the AI",
                value=default_prompt,
                lines=4,
            )
        with gr.Row():
            upload = gr.File(
                label="Upload CSV",
                file_types=[".csv"],
            )
        with gr.Row():
            question_box = gr.Textbox(
                label="Analysis Query",
                placeholder="What would you like to know about the data?",
                lines=2,
            )
        with gr.Row():
            run_button = gr.Button("Analyze")
        with gr.Row():
            result_view = gr.Markdown(label="Analysis Results")
        run_button.click(
            fn=process_inputs,
            inputs=[key_box, prompt_box, upload, question_box],
            outputs=result_view,
        )
    return app
def main():
    """Build the interface and launch it on 0.0.0.0:7860 with ``share=True``.

    Any startup failure is logged and then re-raised to the caller.
    """
    launch_kwargs = dict(share=True, server_name="0.0.0.0", server_port=7860)
    try:
        create_interface().launch(**launch_kwargs)
    except Exception as exc:
        logger.error(f"Application startup failed: {str(exc)}")
        raise


# Start the app only when run as a script, not when imported.
if __name__ == "__main__":
    main()