|
import base64
import io
import json
import logging
from typing import Any, Dict, List, Tuple

import matplotlib

matplotlib.use("Agg")  # non-interactive backend: figures are only rendered to buffers

import google.generativeai as genai
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
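
# Third-party dependencies (install names assumed to match the imports):
#   pip install pandas google-generativeai gradio matplotlib seaborn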
|
|
|
class DataTools:
    """Tools for data analysis that can be called by the AI."""

    def __init__(self, df: pd.DataFrame):
        self.df = df

    def describe_column(self, column: str) -> dict:
        """Get a statistical description of a column."""
        if column not in self.df.columns:
            return {"error": f"Column {column} not found"}

        stats = self.df[column].describe().to_dict()
        null_count = self.df[column].isnull().sum()
        return {
            "statistics": stats,
            "null_count": int(null_count),
            "dtype": str(self.df[column].dtype),
        }
|
    def create_visualization(self, plot_type: str, x: str, y: str = None, title: str = None) -> str:
        """Create a visualization and return it as a base64-encoded PNG string."""
        try:
            plt.figure(figsize=(10, 6))
            if plot_type == "histogram":
                sns.histplot(data=self.df, x=x)
            elif plot_type == "scatter":
                sns.scatterplot(data=self.df, x=x, y=y)
            elif plot_type == "boxplot":
                sns.boxplot(data=self.df, x=x, y=y)
            elif plot_type == "bar":
                sns.barplot(data=self.df, x=x, y=y)
            else:
                plt.close()
                return f"Error: unsupported plot type '{plot_type}'"

            if title:
                plt.title(title)

            buf = io.BytesIO()
            plt.savefig(buf, format="png")
            buf.seek(0)
            plt.close()
            return base64.b64encode(buf.read()).decode("utf-8")
        except Exception as e:
            plt.close()  # don't leak the figure if plotting fails
            return f"Error creating visualization: {str(e)}"
|
    def get_correlation(self, columns: List[str]) -> dict:
        """Get the correlation matrix for the specified columns."""
        try:
            corr = self.df[columns].corr().to_dict()
            return {"correlation_matrix": corr}
        except Exception as e:
            return {"error": f"Error calculating correlation: {str(e)}"}
|
|
|
class DataAnalyzer:
    def __init__(self):
        self.model = None
        self.api_key = None
        self.system_prompt = None
        self.df = None
        self.tools = None

    def configure_api(self, api_key: str) -> bool:
        """Configure the Gemini API with the provided key."""
        try:
            genai.configure(api_key=api_key)
            self.model = genai.GenerativeModel("gemini-1.5-pro")
            self.api_key = api_key
            return True
        except Exception as e:
            logger.error(f"API configuration failed: {str(e)}")
            return False
|
    def load_data(self, file) -> Tuple[bool, str]:
        """Load data from an uploaded CSV file."""
        try:
            # Depending on the Gradio version, the upload arrives either as a
            # tempfile-like object with a .name attribute or as a plain path.
            path = file.name if hasattr(file, "name") else file
            self.df = pd.read_csv(path)
            self.tools = DataTools(self.df)
            return True, f"Loaded CSV with {len(self.df)} rows and {len(self.df.columns)} columns"
        except Exception as e:
            logger.error(f"Data loading failed: {str(e)}")
            return False, f"Error loading data: {str(e)}"
|
    def get_data_info(self) -> Dict[str, Any]:
        """Get summary information about the loaded data."""
        if self.df is None:
            return {"error": "No data loaded"}

        return {
            "columns": list(self.df.columns),
            "rows": len(self.df),
            "sample": self.df.head(5).to_dict("records"),
            "dtypes": self.df.dtypes.astype(str).to_dict(),
        }
|
    def analyze(self, query: str) -> str:
        """Analyze data based on a user query and return Markdown-formatted results."""
        if self.model is None:
            return "Error: please configure the API key first."
        if self.df is None:
            return "Error: please upload a CSV file first."

        data_info = self.get_data_info()

        # Literal braces in the JSON template below are doubled so the f-string
        # does not try to interpolate them. default=str keeps json.dumps happy
        # with NumPy scalars and timestamps in the sample rows.
        prompt = f"""{self.system_prompt}

Data Information:
- Columns: {data_info['columns']}
- Number of rows: {data_info['rows']}
- Sample data: {json.dumps(data_info['sample'], indent=2, default=str)}

Available Tools:
1. describe_column(column: str) - Get statistical description of a column
2. create_visualization(plot_type: str, x: str, y: str = None, title: str = None)
   - Create visualizations (types: histogram, scatter, boxplot, bar)
3. get_correlation(columns: List[str]) - Get correlation between columns

User Query: {query}

Please provide a structured analysis in the following JSON format:
{{
    "answer": "Direct answer to the query",
    "tools_used": [
        {{
            "tool": "tool_name",
            "parameters": {{"param1": "value1"}},
            "purpose": "Why this tool was used"
        }}
    ],
    "insights": ["List of key insights"],
    "visualizations": ["List of suggested visualizations"],
    "recommendations": ["List of recommendations"],
    "limitations": ["Any limitations in the analysis"]
}}

Important:
- Be specific about which tools to use
- Provide clear reasoning for each tool choice
- Structure the output exactly as shown above
- Return only the JSON object, with no surrounding markdown fences
"""
|
        try:
            response = self.model.generate_content(prompt)
            response_text = response.text

            # Gemini often wraps JSON in markdown fences; strip them before parsing.
            cleaned = response_text.strip()
            if cleaned.startswith("```"):
                cleaned = cleaned.removeprefix("```json").removeprefix("```")
                cleaned = cleaned.removesuffix("```").strip()

            try:
                structured_response = json.loads(cleaned)

                # Execute each tool call the model requested and collect the outputs.
                results = {"response": structured_response, "tool_outputs": []}
                for tool_call in structured_response.get("tools_used", []):
                    tool_name = tool_call.get("tool")
                    parameters = tool_call.get("parameters", {})

                    if tool_name and hasattr(self.tools, tool_name):
                        tool_method = getattr(self.tools, tool_name)
                        tool_result = tool_method(**parameters)
                        results["tool_outputs"].append({
                            "tool": tool_name,
                            "parameters": parameters,
                            "result": tool_result,
                        })

                # Build the Markdown sections up front: f-string expressions may
                # not contain backslashes before Python 3.12, so the newline
                # joins happen here rather than inline.
                def bullets(items):
                    return "\n".join(f"- {item}" for item in items)

                tool_sections = "\n".join(
                    f'\n**{out["tool"]}**:\n```json\n{json.dumps(out["result"], indent=2, default=str)}\n```'
                    for out in results["tool_outputs"]
                )

                return f"""## Analysis Results

{structured_response.get('answer', '')}

### Key Insights
{bullets(structured_response.get('insights', []))}

### Visualizations
{bullets(structured_response.get('visualizations', []))}

### Recommendations
{bullets(structured_response.get('recommendations', []))}

### Limitations
{bullets(structured_response.get('limitations', []))}

---
Tool Outputs:
{tool_sections}
"""

            except json.JSONDecodeError:
                return f"Error: Could not parse structured response\n\nRaw response:\n{response_text}"

        except Exception as e:
            logger.error(f"Analysis failed: {str(e)}")
            return f"Error during analysis: {str(e)}"
|
|
|
def create_interface():
    """Create the Gradio interface."""
    analyzer = DataAnalyzer()

    def process_inputs(api_key: str, system_prompt: str, file, query: str):
        """Process user inputs and return analysis results."""
        if api_key != analyzer.api_key:
            if not analyzer.configure_api(api_key):
                return "Failed to configure API. Please check your API key."

        analyzer.system_prompt = system_prompt

        if file is not None:
            success, message = analyzer.load_data(file)
            if not success:
                return message

        if not query:
            return "Please enter an analysis query."

        return analyzer.analyze(query)
|
    with gr.Blocks(title="Advanced Data Analysis Assistant") as interface:
        gr.Markdown("# Advanced Data Analysis Assistant")
        gr.Markdown("Upload your CSV file and get AI-powered analysis with visualizations")

        with gr.Row():
            api_key_input = gr.Textbox(
                label="Gemini API Key",
                placeholder="Enter your Gemini API key",
                type="password",
            )

        with gr.Row():
            system_prompt_input = gr.Textbox(
                label="System Prompt",
                placeholder="Enter system prompt for the AI",
                value="""You are an advanced data analysis expert. Analyze the provided data and answer the query.
Focus on:
1. Clear, structured analysis
2. Statistical insights
3. Appropriate visualizations
4. Actionable recommendations""",
                lines=4,
            )

        with gr.Row():
            file_input = gr.File(
                label="Upload CSV",
                file_types=[".csv"],
            )

        with gr.Row():
            query_input = gr.Textbox(
                label="Analysis Query",
                placeholder="What would you like to know about the data?",
                lines=2,
            )

        with gr.Row():
            submit_btn = gr.Button("Analyze")

        with gr.Row():
            output = gr.Markdown(label="Analysis Results")

        submit_btn.click(
            fn=process_inputs,
            inputs=[api_key_input, system_prompt_input, file_input, query_input],
            outputs=output,
        )

    return interface
|
|
|
def main():
    """Main application entry point."""
    try:
        interface = create_interface()
        interface.launch(
            share=True,  # set to False to keep the app local-only
            server_name="0.0.0.0",
            server_port=7860,
        )
    except Exception as e:
        logger.error(f"Application startup failed: {str(e)}")
        raise


if __name__ == "__main__":
    main()