File size: 10,288 Bytes
37336a7
eb04de8
dfc517e
 
 
 
 
4dec0f2
 
 
 
7cdbd20
eb04de8
 
 
 
4dec0f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfc517e
 
 
 
 
 
4dec0f2
dfc517e
 
4dec0f2
eb04de8
dfc517e
 
 
 
eb04de8
dfc517e
 
 
 
 
eb04de8
dfc517e
4dec0f2
dfc517e
eb04de8
dfc517e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4dec0f2
 
dfc517e
4dec0f2
dfc517e
4dec0f2
dfc517e
 
 
4dec0f2
dfc517e
 
 
 
 
 
 
4dec0f2
 
 
 
 
 
dfc517e
 
4dec0f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfc517e
eb04de8
4dec0f2
dfc517e
4dec0f2
eb04de8
4dec0f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb04de8
 
dfc517e
4dec0f2
eb04de8
dfc517e
 
 
eb04de8
dfc517e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4dec0f2
 
 
e7486fb
882008c
dfc517e
 
 
 
eb04de8
dfc517e
 
 
 
 
4dec0f2
 
 
 
 
 
 
eb04de8
 
dfc517e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb04de8
dfc517e
 
 
 
 
 
 
069bc6a
 
eb04de8
 
 
 
 
dfc517e
eb04de8
dfc517e
eb04de8
dfc517e
eb04de8
 
dfc517e
eb04de8
37336a7
 
eb04de8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
import os
import logging
import pandas as pd
import google.generativeai as genai
import gradio as gr
from typing import Dict, List, Any, Tuple
import json
import matplotlib.pyplot as plt
import seaborn as sns
import io
import base64

# Module-wide logging: INFO and above go through the root handler, and all
# messages in this file are emitted via a logger named after the module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

class DataTools:
    """Tools for data analysis that can be called by the AI.

    Each method takes simple arguments (as produced by the model's JSON reply)
    and returns a plain dict or string that can be shown in the report.
    """

    # Plot types understood by create_visualization.
    SUPPORTED_PLOTS = ("histogram", "scatter", "boxplot", "bar")

    def __init__(self, df: pd.DataFrame):
        self.df = df

    def describe_column(self, column: str) -> dict:
        """Get statistical description of a column.

        Returns:
            ``{"statistics": <Series.describe() as dict>, "null_count": int,
            "dtype": str}``, or ``{"error": ...}`` if the column does not exist.
        """
        if column not in self.df.columns:
            return {"error": f"Column {column} not found"}

        series = self.df[column]
        return {
            "statistics": series.describe().to_dict(),
            "null_count": int(series.isnull().sum()),
            "dtype": str(series.dtype),
        }

    def create_visualization(self, plot_type: str, x: str, y: str = None, title: str = None) -> str:
        """Create a visualization and return it as a base64-encoded PNG string.

        Args:
            plot_type: One of ``SUPPORTED_PLOTS``.
            x: Column for the x axis.
            y: Optional column for the y axis (not used for ``histogram``).
            title: Optional plot title.

        Returns:
            The PNG bytes encoded as base64 text, or a message starting with
            ``"Error creating visualization"`` on failure.
        """
        # Reject unknown types up front instead of returning an empty image.
        if plot_type not in self.SUPPORTED_PLOTS:
            return (
                f"Error creating visualization: unsupported plot type {plot_type!r}; "
                f"expected one of {', '.join(self.SUPPORTED_PLOTS)}"
            )

        fig = None
        try:
            fig, ax = plt.subplots(figsize=(10, 6))
            if plot_type == "histogram":
                sns.histplot(data=self.df, x=x, ax=ax)
            elif plot_type == "scatter":
                sns.scatterplot(data=self.df, x=x, y=y, ax=ax)
            elif plot_type == "boxplot":
                sns.boxplot(data=self.df, x=x, y=y, ax=ax)
            else:  # "bar"
                sns.barplot(data=self.df, x=x, y=y, ax=ax)

            if title:
                ax.set_title(title)

            buf = io.BytesIO()
            fig.savefig(buf, format="png")
            return base64.b64encode(buf.getvalue()).decode("utf-8")
        except Exception as e:
            return f"Error creating visualization: {str(e)}"
        finally:
            # Always release the figure; otherwise failed plots accumulate in
            # pyplot's global figure registry.
            if fig is not None:
                plt.close(fig)

    def get_correlation(self, columns: List[str]) -> dict:
        """Get the correlation matrix between the specified columns.

        ``columns`` may also be a single column name or any iterable of names
        (the model does not always send a JSON list).

        Returns:
            ``{"correlation_matrix": {col: {col: value}}}`` or ``{"error": ...}``.
        """
        if isinstance(columns, str):
            columns = [columns]
        try:
            # list(): a tuple would be interpreted by pandas as a single key.
            corr = self.df[list(columns)].corr().to_dict()
            return {"correlation_matrix": corr}
        except Exception as e:
            return {"error": f"Error calculating correlation: {str(e)}"}

class DataAnalyzer:
    """Glue between the Gemini model, the uploaded DataFrame and DataTools.

    Usage: call ``configure_api`` with a key, ``load_data`` with an uploaded
    CSV, then ``analyze`` with a natural-language query. ``analyze`` always
    returns a Markdown string (error conditions included) so the result can be
    displayed directly.
    """

    # DataTools methods the model may invoke. The model's reply is untrusted
    # input, so tool names are checked against this allow-list rather than
    # calling whatever attribute happens to exist on the tools object.
    ALLOWED_TOOLS = ("describe_column", "create_visualization", "get_correlation")

    # (response key, Markdown heading) for the bullet-list report sections.
    _LIST_SECTIONS = (
        ("insights", "Key Insights"),
        ("visualizations", "Visualizations"),
        ("recommendations", "Recommendations"),
        ("limitations", "Limitations"),
    )

    def __init__(self):
        self.model = None  # Gemini model, set by configure_api
        self.api_key = None  # last key configure_api succeeded with
        self.system_prompt = None  # prepended to every analysis prompt
        self.df = None  # currently loaded DataFrame
        self.tools = None  # DataTools bound to self.df

    def configure_api(self, api_key: str) -> bool:
        """Configure the Gemini API with ``api_key`` and create the model.

        Returns True on success, False (after logging the error) on failure.
        """
        try:
            genai.configure(api_key=api_key)
            self.model = genai.GenerativeModel('gemini-1.5-pro')
            self.api_key = api_key
            return True
        except Exception as e:
            logger.error(f"API configuration failed: {str(e)}")
            return False

    def load_data(self, file) -> Tuple[bool, str]:
        """Load data from an uploaded CSV file.

        Args:
            file: A filesystem path (``str`` / ``os.PathLike``, as newer Gradio
                versions pass for ``gr.File``) or an object whose ``.name``
                holds the path (older Gradio tempfile wrappers).

        Returns:
            ``(True, summary)`` on success, ``(False, error_message)`` on
            failure; on failure any previously loaded data is kept.
        """
        try:
            path = file if isinstance(file, (str, os.PathLike)) else file.name
            df = pd.read_csv(path)
            self.df = df
            self.tools = DataTools(df)
            return True, f"Loaded CSV with {len(df)} rows and {len(df.columns)} columns"
        except Exception as e:
            logger.error(f"Data loading failed: {str(e)}")
            return False, f"Error loading data: {str(e)}"

    def get_data_info(self) -> Dict[str, Any]:
        """Get column names, row count, first five rows and dtypes of the data."""
        if self.df is None:
            return {"error": "No data loaded"}

        return {
            "columns": list(self.df.columns),
            "rows": len(self.df),
            "sample": self.df.head(5).to_dict('records'),
            "dtypes": self.df.dtypes.astype(str).to_dict(),
        }

    def analyze(self, query: str) -> str:
        """Analyze the loaded data for ``query`` and return a Markdown report.

        The model is asked for a JSON object; the tool calls it lists are run
        against ``self.tools`` and their results are appended to the report.
        Missing setup, model failures and unparseable replies are returned as
        text rather than raised.
        """
        if self.model is None:
            return "Please configure API key first"
        if self.df is None:
            return "Please upload a CSV file first"

        try:
            prompt = self._build_prompt(query)
            response_text = self.model.generate_content(prompt).text
        except Exception as e:
            logger.error(f"Analysis failed: {str(e)}")
            return f"Error during analysis: {str(e)}"

        try:
            structured_response = self._parse_json_reply(response_text)
        except json.JSONDecodeError:
            structured_response = None
        # Anything but a JSON object cannot be rendered as a report.
        if not isinstance(structured_response, dict):
            return (
                "Error: Could not parse structured response\n\n"
                f"Raw response:\n{response_text}"
            )

        try:
            tool_outputs = self._run_tools(structured_response.get("tools_used"))
            return self._format_report(structured_response, tool_outputs)
        except Exception as e:
            logger.error(f"Analysis failed: {str(e)}")
            return f"Error during analysis: {str(e)}"

    def _build_prompt(self, query: str) -> str:
        """Combine system prompt, data context, tool descriptions and the query.

        The JSON template's literal braces are doubled because this is an
        f-string; single braces would be parsed as replacement fields.
        """
        data_info = self.get_data_info()
        system_prompt = self.system_prompt or ""
        # default=str: sample rows may contain values the json module cannot
        # encode natively (e.g. pandas Timestamps from date columns).
        sample_json = json.dumps(data_info["sample"], indent=2, default=str)
        return f"""{system_prompt}

Data Information:
- Columns: {data_info['columns']}
- Number of rows: {data_info['rows']}
- Sample data: {sample_json}

Available Tools:
1. describe_column(column: str) - Get statistical description of a column
2. create_visualization(plot_type: str, x: str, y: str = None, title: str = None)
   - Create visualizations (types: histogram, scatter, boxplot, bar)
3. get_correlation(columns: List[str]) - Get correlation between columns

User Query: {query}

Please provide a structured analysis in the following JSON format:
{{
    "answer": "Direct answer to the query",
    "tools_used": [
        {{
            "tool": "tool_name",
            "parameters": {{"param1": "value1"}},
            "purpose": "Why this tool was used"
        }}
    ],
    "insights": ["List of key insights"],
    "visualizations": ["List of suggested visualizations"],
    "recommendations": ["List of recommendations"],
    "limitations": ["Any limitations in the analysis"]
}}

Important:
- Be specific about which tools to use
- Provide clear reasoning for each tool choice
- Structure the output exactly as shown above
- Reply with the JSON object only, without Markdown code fences or extra text
"""

    @staticmethod
    def _parse_json_reply(text: Any) -> Any:
        """Parse the model reply as JSON, tolerating a surrounding ``` fence.

        Raises:
            json.JSONDecodeError: if the (unfenced) text is not valid JSON.
        """
        cleaned = str(text).strip()
        if cleaned.startswith("```"):
            # Drop the opening fence line (``` or ```json) and a closing fence.
            cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else ""
            cleaned = cleaned.rstrip()
            if cleaned.endswith("```"):
                cleaned = cleaned[:-3]
        return json.loads(cleaned)

    @staticmethod
    def _as_list(value: Any) -> List[Any]:
        """Normalize a model-provided section: None -> [], scalar -> [scalar]."""
        if value is None:
            return []
        if isinstance(value, list):
            return value
        return [value]

    def _run_tools(self, tool_calls: Any) -> List[Dict[str, Any]]:
        """Execute the allow-listed tool calls requested by the model.

        Malformed entries and tools outside ``ALLOWED_TOOLS`` (or missing from
        ``self.tools``) are skipped. A tool that raises is recorded with an
        ``{"error": ...}`` result instead of aborting the whole analysis.
        """
        outputs = []
        for call in self._as_list(tool_calls):
            if not isinstance(call, dict):
                continue
            tool_name = call.get("tool")
            if tool_name not in self.ALLOWED_TOOLS or not hasattr(self.tools, tool_name):
                continue
            parameters = call.get("parameters") or {}
            if not isinstance(parameters, dict):
                result = {"error": f"Tool {tool_name} failed: parameters must be a JSON object"}
            else:
                try:
                    result = getattr(self.tools, tool_name)(**parameters)
                except Exception as e:
                    result = {"error": f"Tool {tool_name} failed: {str(e)}"}
            outputs.append({"tool": tool_name, "parameters": parameters, "result": result})
        return outputs

    def _format_report(self, structured_response: Dict[str, Any],
                       tool_outputs: List[Dict[str, Any]]) -> str:
        """Render the structured model reply and tool results as Markdown."""
        lines = ["## Analysis Results", "", str(structured_response.get("answer", "")), ""]
        for key, heading in self._LIST_SECTIONS:
            lines.append(f"### {heading}")
            lines.extend(f"- {item}" for item in self._as_list(structured_response.get(key)))
            lines.append("")

        lines += ["---", "Tool Outputs:"]
        for output in tool_outputs:
            lines += ["", f"**{output['tool']}**:"]
            result = output["result"]
            if (output["tool"] == "create_visualization" and isinstance(result, str)
                    and not result.startswith("Error")):
                # create_visualization returns base64 PNG data; embed it as an
                # image instead of dumping the encoded bytes as text.
                lines.append(f"![visualization](data:image/png;base64,{result})")
            else:
                # default=str keeps display working for values json can't encode.
                lines += ["```json", json.dumps(result, indent=2, default=str), "```"]
        return "\n".join(lines) + "\n"

def create_interface():
    """Create the Gradio interface.

    Builds a single-page Blocks app (API key, system prompt, CSV upload and
    query inputs) whose "Analyze" button renders the result of
    ``DataAnalyzer.analyze`` into a Markdown panel.

    NOTE(review): one DataAnalyzer instance is captured by ``process_inputs``
    and reused for every click from every user, so the configured key and the
    loaded DataFrame are shared between users of the app.
    """
    analyzer = DataAnalyzer()

    def process_inputs(api_key: str, system_prompt: str, file, query: str):
        """Apply the form values to the analyzer and run the query.

        Returns text for the Markdown panel; validation and setup failures are
        returned as messages rather than raised.
        """
        # Pasted keys often carry surrounding whitespace; a blank key would
        # otherwise be "configured" and only fail later inside the API call.
        api_key = (api_key or "").strip()
        if not api_key:
            return "Please enter your Gemini API key."
        if api_key != analyzer.api_key and not analyzer.configure_api(api_key):
            return "Failed to configure API. Please check your API key."

        analyzer.system_prompt = system_prompt

        if file is not None:
            success, message = analyzer.load_data(file)
            if not success:
                return message

        if not (query or "").strip():
            return "Please enter an analysis query."

        result = analyzer.analyze(query)
        # Defensive: the Markdown panel needs text, so render an
        # {"error": ...} dict as its message.
        if isinstance(result, dict):
            return str(result.get("error", result))
        return result

    # Create Gradio interface
    with gr.Blocks(title="Advanced Data Analysis Assistant") as interface:
        gr.Markdown("# Advanced Data Analysis Assistant")
        gr.Markdown("Upload your CSV file and get AI-powered analysis with visualizations")

        with gr.Row():
            api_key_input = gr.Textbox(
                label="Gemini API Key",
                placeholder="Enter your Gemini API key",
                type="password"
            )

        with gr.Row():
            system_prompt_input = gr.Textbox(
                label="System Prompt",
                placeholder="Enter system prompt for the AI",
                value="""You are an advanced data analysis expert. Analyze the provided data and answer the query.
Focus on:
1. Clear, structured analysis
2. Statistical insights
3. Appropriate visualizations
4. Actionable recommendations""",
                lines=4
            )

        with gr.Row():
            file_input = gr.File(
                label="Upload CSV",
                file_types=[".csv"]
            )

        with gr.Row():
            query_input = gr.Textbox(
                label="Analysis Query",
                placeholder="What would you like to know about the data?",
                lines=2
            )

        with gr.Row():
            submit_btn = gr.Button("Analyze")

        with gr.Row():
            output = gr.Markdown(label="Analysis Results")

        submit_btn.click(
            fn=process_inputs,
            inputs=[api_key_input, system_prompt_input, file_input, query_input],
            outputs=output
        )

    return interface

def main():
    """Build the UI and serve it on 0.0.0.0:7860 with a Gradio share link.

    Any exception raised while building or launching is logged and then
    re-raised to the caller.
    """
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    try:
        app = create_interface()
        app.launch(**launch_options)
    except Exception as exc:
        logger.error(f"Application startup failed: {str(exc)}")
        raise


if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    main()