"""Gradio app that lets Gemini analyse an uploaded CSV using a small tool set."""

import base64
import io
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple

import google.generativeai as genai
import gradio as gr
import matplotlib

# Use the non-interactive Agg backend: figures are only rendered into PNG
# buffers from Gradio request handlers, never shown in a GUI window.
matplotlib.use("Agg")

import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)
import pandas as pd  # noqa: E402
import seaborn as sns  # noqa: E402

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Plot types understood by DataTools.create_visualization.
SUPPORTED_PLOT_TYPES = ("histogram", "scatter", "boxplot", "bar")

# JSON layout the model is asked to answer with. Kept outside any f-string so
# its literal braces are not interpreted as replacement fields.
RESPONSE_FORMAT = """{
  "answer": "Direct answer to the query",
  "tools_used": [
    {
      "tool": "tool_name",
      "parameters": {"param1": "value1"},
      "purpose": "Why this tool was used"
    }
  ],
  "insights": ["List of key insights"],
  "visualizations": ["List of suggested visualizations"],
  "recommendations": ["List of recommendations"],
  "limitations": ["Any limitations in the analysis"]
}"""


def _json_default(obj: Any) -> Any:
    """Fallback for ``json.dumps``: unwrap numpy scalars, stringify the rest.

    pandas results (``describe()``, ``head().to_dict()``) may contain numpy
    integers or Timestamps, which the json module cannot encode natively.
    """
    item = getattr(obj, "item", None)
    if callable(item):
        try:
            return item()
        except (TypeError, ValueError):
            pass
    return str(obj)


def _extract_json_text(text: str) -> str:
    """Return ``text`` with a surrounding Markdown code fence removed, if any.

    Models frequently wrap JSON answers in a fenced block such as ```json ... ```;
    ``json.loads`` rejects the fence markers.
    """
    stripped = text.strip()
    if stripped.startswith("```"):
        # Drop the opening fence line (it may carry a language tag like "json").
        _, _, stripped = stripped.partition("\n")
        stripped = stripped.rstrip()
        if stripped.endswith("```"):
            stripped = stripped[:-3]
    return stripped.strip()


def _bullets(items: Any) -> str:
    """Render ``items`` as Markdown bullet lines, one ``- item`` per line.

    A non-list value is treated as a single item; ``None`` or empty yields "".
    """
    if not items:
        return ""
    if not isinstance(items, list):
        items = [items]
    return "".join(f"- {item}\n" for item in items)


def _format_tool_output(output: Dict[str, Any]) -> str:
    """Render one executed tool call (``tool``/``parameters``/``result``) as Markdown."""
    tool = output["tool"]
    result = output["result"]
    if (
        tool == "create_visualization"
        and isinstance(result, str)
        and not result.startswith("Error")
    ):
        # Successful plots are base64-encoded PNGs: embed them as data-URI
        # images instead of dumping the encoded text.
        params = output["parameters"]
        caption = params.get("title") or params.get("plot_type") or "plot"
        return f"\n**{tool}**:\n\n![{caption}](data:image/png;base64,{result})\n"
    body = json.dumps(result, indent=2, default=_json_default)
    return f"\n**{tool}**:\n```json\n{body}\n```\n"


def _format_analysis(response: Dict[str, Any], tool_outputs: List[Dict[str, Any]]) -> str:
    """Render the parsed model response plus executed tool outputs as Markdown."""
    sections = [
        "## Analysis Results",
        str(response.get("answer", "")),
        "",
        "### Key Insights",
        _bullets(response.get("insights")),
        "### Visualizations",
        _bullets(response.get("visualizations")),
        "### Recommendations",
        _bullets(response.get("recommendations")),
        "### Limitations",
        _bullets(response.get("limitations")),
        "---",
        "Tool Outputs:",
        "".join(_format_tool_output(out) for out in tool_outputs),
    ]
    return "\n".join(sections)


class DataTools:
    """Tools for data analysis that can be called by the AI.

    Only methods named in ``ALLOWED_TOOLS`` may be invoked from a model response.
    """

    # Whitelist of callable tool names. Without it a model response could name
    # any attribute of this object (e.g. ``__init__`` or ``df``).
    ALLOWED_TOOLS = frozenset({"describe_column", "create_visualization", "get_correlation"})

    def __init__(self, df: pd.DataFrame):
        self.df = df

    def describe_column(self, column: str) -> dict:
        """Get statistical description, null count and dtype of a column.

        Returns ``{"error": ...}`` when the column does not exist.
        """
        if column not in self.df.columns:
            return {"error": f"Column {column} not found"}
        stats = self.df[column].describe().to_dict()
        null_count = self.df[column].isnull().sum()
        return {
            "statistics": stats,
            "null_count": int(null_count),
            "dtype": str(self.df[column].dtype),
        }

    def create_visualization(self, plot_type: str, x: str, y: str = None, title: str = None) -> str:
        """Create a plot and return it as a base64-encoded PNG string.

        Args:
            plot_type: One of ``SUPPORTED_PLOT_TYPES``.
            x: Column for the x axis.
            y: Optional column for the y axis (ignored for histograms).
            title: Optional plot title.

        Returns:
            The base64 PNG, or a message starting with
            ``"Error creating visualization:"`` on failure.
        """
        if plot_type not in SUPPORTED_PLOT_TYPES:
            return (
                f"Error creating visualization: unsupported plot type {plot_type!r}; "
                f"expected one of {', '.join(SUPPORTED_PLOT_TYPES)}"
            )
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            if plot_type == "histogram":
                sns.histplot(data=self.df, x=x, ax=ax)
            elif plot_type == "scatter":
                sns.scatterplot(data=self.df, x=x, y=y, ax=ax)
            elif plot_type == "boxplot":
                sns.boxplot(data=self.df, x=x, y=y, ax=ax)
            else:  # "bar"
                sns.barplot(data=self.df, x=x, y=y, ax=ax)
            if title:
                ax.set_title(title)
            buf = io.BytesIO()
            fig.savefig(buf, format="png")
            return base64.b64encode(buf.getvalue()).decode("utf-8")
        except Exception as e:
            return f"Error creating visualization: {str(e)}"
        finally:
            # Release the figure on every path; pyplot keeps open figures alive.
            plt.close(fig)

    def get_correlation(self, columns: List[str]) -> dict:
        """Get the correlation matrix between the given columns.

        A single column name passed as a plain string is accepted as a
        one-element list. Returns ``{"error": ...}`` on failure.
        """
        try:
            if isinstance(columns, str):
                columns = [columns]
            corr = self.df[columns].corr().to_dict()
            return {"correlation_matrix": corr}
        except Exception as e:
            return {"error": f"Error calculating correlation: {str(e)}"}


class DataAnalyzer:
    """Holds the Gemini model and loaded DataFrame, and runs LLM-driven analyses."""

    def __init__(self, model_name: str = "gemini-1.5-pro"):
        self.model = None
        self.model_name = model_name
        self.api_key = None
        self.system_prompt = None
        self.df = None
        self.tools = None

    def configure_api(self, api_key: str) -> bool:
        """Configure the Gemini API with the provided key and create the model.

        Returns True on success, False (after logging the error) on failure.
        """
        try:
            genai.configure(api_key=api_key)
            self.model = genai.GenerativeModel(self.model_name)
            self.api_key = api_key
            return True
        except Exception:
            logger.exception("API configuration failed")
            return False

    def load_data(self, file) -> Tuple[bool, str]:
        """Load data from an uploaded CSV file and rebuild the tool set.

        ``file`` may be a filesystem path (the value Gradio 4+ ``gr.File``
        passes by default) or a tempfile-like object exposing ``.name``
        (older Gradio versions).

        Returns ``(success, message)``.
        """
        try:
            path = file if isinstance(file, (str, os.PathLike)) else file.name
            self.df = pd.read_csv(path)
            self.tools = DataTools(self.df)
            return True, f"Loaded CSV with {len(self.df)} rows and {len(self.df.columns)} columns"
        except Exception as e:
            logger.exception("Data loading failed")
            return False, f"Error loading data: {str(e)}"

    def get_data_info(self) -> Dict[str, Any]:
        """Get columns, row count, a 5-row sample and dtypes of the loaded data."""
        if self.df is None:
            return {"error": "No data loaded"}
        return {
            "columns": list(self.df.columns),
            "rows": len(self.df),
            "sample": self.df.head(5).to_dict("records"),
            "dtypes": self.df.dtypes.astype(str).to_dict(),
        }

    def _build_prompt(self, query: str) -> str:
        """Combine the system prompt, data context and tool instructions."""
        data_info = self.get_data_info()
        sample = json.dumps(data_info["sample"], indent=2, default=_json_default)
        return f"""{self.system_prompt or ''}

Data Information:
- Columns: {data_info['columns']}
- Column types: {data_info['dtypes']}
- Number of rows: {data_info['rows']}
- Sample data: {sample}

Available Tools:
1. describe_column(column: str) - Get statistical description of a column
2. create_visualization(plot_type: str, x: str, y: str = None, title: str = None) - Create visualizations (types: histogram, scatter, boxplot, bar)
3. get_correlation(columns: List[str]) - Get correlation between columns

User Query: {query}

Please provide a structured analysis in the following JSON format:
{RESPONSE_FORMAT}

Important:
- Be specific about which tools to use
- Provide clear reasoning for each tool choice
- Structure the output exactly as shown above
- Respond with the JSON object only
"""

    def _run_tools(self, tool_calls: Any) -> List[Dict[str, Any]]:
        """Execute the whitelisted tool calls requested by the model.

        Malformed entries and tools outside ``DataTools.ALLOWED_TOOLS`` are
        skipped with a warning. A tool that raises is recorded with an
        ``{"error": ...}`` result instead of aborting the whole analysis.
        """
        outputs: List[Dict[str, Any]] = []
        if not isinstance(tool_calls, list):
            return outputs
        for call in tool_calls:
            if not isinstance(call, dict):
                continue
            tool_name = call.get("tool")
            parameters = call.get("parameters") or {}
            if tool_name not in DataTools.ALLOWED_TOOLS or not isinstance(parameters, dict):
                logger.warning("Skipping unsupported tool call: %r", call)
                continue
            try:
                result = getattr(self.tools, tool_name)(**parameters)
            except Exception as e:  # model-supplied arguments may not fit the tool
                logger.exception("Tool %s failed", tool_name)
                result = {"error": f"Tool {tool_name} failed: {e}"}
            outputs.append({"tool": tool_name, "parameters": parameters, "result": result})
        return outputs

    def analyze(self, query: str) -> str:
        """Analyze the loaded data for ``query`` and return Markdown.

        Sends the data summary and tool descriptions to Gemini, parses its JSON
        reply, runs the requested tools and formats everything as Markdown.
        Every outcome, including errors, is returned as a string because the
        result is rendered directly by a ``gr.Markdown`` component.
        """
        if self.model is None:
            return "Please configure API key first"
        if self.df is None:
            return "Please upload a CSV file first"

        try:
            response_text = self.model.generate_content(self._build_prompt(query)).text

            try:
                structured_response = json.loads(_extract_json_text(response_text))
            except json.JSONDecodeError:
                structured_response = None
            if not isinstance(structured_response, dict):
                return f"Error: Could not parse structured response\n\nRaw response:\n{response_text}"

            tool_outputs = self._run_tools(structured_response.get("tools_used"))
            return _format_analysis(structured_response, tool_outputs)
        except Exception as e:
            logger.exception("Analysis failed")
            return f"Error during analysis: {str(e)}"


def create_interface():
    """Create the Gradio interface."""
    analyzer = DataAnalyzer()

    def process_inputs(api_key: str, system_prompt: str, file, query: str) -> str:
        """Apply the UI inputs to the analyzer and return the analysis Markdown."""
        if not api_key:
            return "Please enter your Gemini API key."
        if api_key != analyzer.api_key:
            if not analyzer.configure_api(api_key):
                return "Failed to configure API. Please check your API key."

        analyzer.system_prompt = system_prompt

        if file is not None:
            success, message = analyzer.load_data(file)
            if not success:
                return message

        if not query or not query.strip():
            return "Please enter an analysis query."

        return analyzer.analyze(query)

    with gr.Blocks(title="Advanced Data Analysis Assistant") as interface:
        gr.Markdown("# Advanced Data Analysis Assistant")
        gr.Markdown("Upload your CSV file and get AI-powered analysis with visualizations")

        with gr.Row():
            api_key_input = gr.Textbox(
                label="Gemini API Key",
                placeholder="Enter your Gemini API key",
                type="password",
            )

        with gr.Row():
            system_prompt_input = gr.Textbox(
                label="System Prompt",
                placeholder="Enter system prompt for the AI",
                value="""You are an advanced data analysis expert. Analyze the provided data and answer the query.
Focus on:
1. Clear, structured analysis
2. Statistical insights
3. Appropriate visualizations
4. Actionable recommendations""",
                lines=4,
            )

        with gr.Row():
            file_input = gr.File(
                label="Upload CSV",
                file_types=[".csv"],
            )

        with gr.Row():
            query_input = gr.Textbox(
                label="Analysis Query",
                placeholder="What would you like to know about the data?",
                lines=2,
            )

        with gr.Row():
            submit_btn = gr.Button("Analyze")

        with gr.Row():
            output = gr.Markdown(label="Analysis Results")

        submit_btn.click(
            fn=process_inputs,
            inputs=[api_key_input, system_prompt_input, file_input, query_input],
            outputs=output,
        )

    return interface


def main():
    """Main application entry point."""
    try:
        interface = create_interface()
        interface.launch(
            share=True,
            server_name="0.0.0.0",
            server_port=7860,
        )
    except Exception:
        logger.exception("Application startup failed")
        raise


if __name__ == "__main__":
    main()