File size: 10,288 Bytes
37336a7 eb04de8 dfc517e 4dec0f2 7cdbd20 eb04de8 4dec0f2 dfc517e 4dec0f2 dfc517e 4dec0f2 eb04de8 dfc517e eb04de8 dfc517e eb04de8 dfc517e 4dec0f2 dfc517e eb04de8 dfc517e 4dec0f2 dfc517e 4dec0f2 dfc517e 4dec0f2 dfc517e 4dec0f2 dfc517e 4dec0f2 dfc517e 4dec0f2 dfc517e eb04de8 4dec0f2 dfc517e 4dec0f2 eb04de8 4dec0f2 eb04de8 dfc517e 4dec0f2 eb04de8 dfc517e eb04de8 dfc517e 4dec0f2 e7486fb 882008c dfc517e eb04de8 dfc517e 4dec0f2 eb04de8 dfc517e eb04de8 dfc517e 069bc6a eb04de8 dfc517e eb04de8 dfc517e eb04de8 dfc517e eb04de8 dfc517e eb04de8 37336a7 eb04de8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 |
import os
import logging
import pandas as pd
import google.generativeai as genai
import gradio as gr
from typing import Dict, List, Any, Tuple
import json
import matplotlib.pyplot as plt
import seaborn as sns
import io
import base64
# Module-level logging setup: the root logger is configured to emit INFO and
# above, and this module writes through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
class DataTools:
    """Tools for data analysis that can be called by the AI.

    Failures the tools anticipate are reported in-band (an ``{"error": ...}``
    dict or a string starting with ``"Error"``) instead of being raised, because
    tool names and arguments come straight from model output.
    """

    # Plot kinds understood by create_visualization; checked before any
    # matplotlib figure is created.
    _PLOT_TYPES = ("histogram", "scatter", "boxplot", "bar")

    def __init__(self, df: pd.DataFrame):
        # DataFrame that every tool reads from.
        self.df = df

    def describe_column(self, column: str) -> dict:
        """Get statistical description of a column.

        Returns:
            ``{"statistics": <describe() as dict>, "null_count": int,
            "dtype": str}``, or ``{"error": ...}`` if the column is missing.
        """
        if column not in self.df.columns:
            return {"error": f"Column {column} not found"}
        series = self.df[column]
        return {
            "statistics": series.describe().to_dict(),
            "null_count": int(series.isnull().sum()),
            "dtype": str(series.dtype),
        }

    def create_visualization(self, plot_type: str, x: str, y: str = None, title: str = None) -> str:
        """Create a visualization and return it as a base64-encoded PNG string.

        Args:
            plot_type: One of "histogram", "scatter", "boxplot", "bar".
            x: Column plotted on the x axis.
            y: Column plotted on the y axis (not used for histograms).
            title: Optional plot title.

        Returns:
            The base64 PNG text on success, otherwise a message starting with
            ``"Error creating visualization:"``.
        """
        if plot_type not in self._PLOT_TYPES:
            # Reject up front; otherwise an empty figure would be encoded and
            # returned as though a plot had been drawn.
            return (
                "Error creating visualization: unsupported plot_type "
                f"{plot_type!r}; expected one of {', '.join(self._PLOT_TYPES)}"
            )
        fig = None
        try:
            fig = plt.figure(figsize=(10, 6))
            if plot_type == "histogram":
                sns.histplot(data=self.df, x=x)
            elif plot_type == "scatter":
                sns.scatterplot(data=self.df, x=x, y=y)
            elif plot_type == "boxplot":
                sns.boxplot(data=self.df, x=x, y=y)
            elif plot_type == "bar":
                sns.barplot(data=self.df, x=x, y=y)
            if title:
                plt.title(title)
            # Render the figure into an in-memory PNG.
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
            return base64.b64encode(buf.getvalue()).decode('utf-8')
        except Exception as e:
            return f"Error creating visualization: {str(e)}"
        finally:
            # Release the figure on every path, including when seaborn raises
            # (e.g. a bad column name); pyplot otherwise keeps it open.
            if fig is not None:
                plt.close(fig)

    def get_correlation(self, columns: List[str]) -> dict:
        """Get correlation between specified columns.

        Returns ``{"correlation_matrix": {...}}`` or ``{"error": ...}`` if
        pandas cannot compute it (e.g. missing or non-numeric columns).
        """
        try:
            corr = self.df[columns].corr().to_dict()
            return {"correlation_matrix": corr}
        except Exception as e:
            return {"error": f"Error calculating correlation: {str(e)}"}
class DataAnalyzer:
    """Ask a Gemini model to analyze an uploaded CSV and run the tools it requests.

    State is filled in step by step: ``configure_api`` sets ``model`` and
    ``api_key``, ``load_data`` sets ``df`` and ``tools``, and callers assign
    ``system_prompt`` directly before calling ``analyze``.
    """

    # Only these tool methods may be invoked from model output; any other name
    # the model produces (including dunder attributes) is ignored.
    _TOOL_NAMES = ("describe_column", "create_visualization", "get_correlation")

    # JSON shape the model is asked to reply with (serialized into the prompt).
    _RESPONSE_TEMPLATE = {
        "answer": "Direct answer to the query",
        "tools_used": [
            {
                "tool": "tool_name",
                "parameters": {"param1": "value1"},
                "purpose": "Why this tool was used",
            }
        ],
        "insights": ["List of key insights"],
        "visualizations": ["List of suggested visualizations"],
        "recommendations": ["List of recommendations"],
        "limitations": ["Any limitations in the analysis"],
    }

    def __init__(self):
        self.model = None
        self.api_key = None
        self.system_prompt = None
        self.df = None
        self.tools = None

    def configure_api(self, api_key: str):
        """Configure the Gemini API with the provided key.

        ``genai.configure`` is a module-level call, so the key applies to the
        whole ``genai`` module, not just this instance.

        Returns:
            True on success; False (after logging the error) on failure.
        """
        try:
            genai.configure(api_key=api_key)
            self.model = genai.GenerativeModel('gemini-1.5-pro')
            self.api_key = api_key
            return True
        except Exception as e:
            logger.error(f"API configuration failed: {str(e)}")
            return False

    def load_data(self, file) -> Tuple[bool, str]:
        """Load data from an uploaded CSV file.

        Args:
            file: Either a path string or an object exposing the path as
                ``.name`` (Gradio's upload value differs between versions).

        Returns:
            ``(True, summary)`` on success, ``(False, error message)`` otherwise.
        """
        try:
            path = getattr(file, "name", file)
            self.df = pd.read_csv(path)
            self.tools = DataTools(self.df)
            return True, f"Loaded CSV with {len(self.df)} rows and {len(self.df.columns)} columns"
        except Exception as e:
            logger.error(f"Data loading failed: {str(e)}")
            return False, f"Error loading data: {str(e)}"

    def get_data_info(self) -> Dict[str, Any]:
        """Get column names, row count, a 5-row sample and dtypes of the loaded data."""
        if self.df is None:
            return {"error": "No data loaded"}
        info = {
            "columns": list(self.df.columns),
            "rows": len(self.df),
            "sample": self.df.head(5).to_dict('records'),
            "dtypes": self.df.dtypes.astype(str).to_dict()
        }
        return info

    def analyze(self, query: str) -> str:
        """Analyze the loaded data for ``query`` and return a Markdown report.

        Every outcome, including errors, is returned as a string so the result
        can be shown directly in a ``gr.Markdown`` component.
        """
        if self.model is None:
            return "Please configure API key first"
        if self.df is None:
            return "Please upload a CSV file first"
        try:
            prompt = self._build_prompt(query)
            response_text = self.model.generate_content(prompt).text
            try:
                structured_response = json.loads(self._strip_code_fence(response_text))
            except json.JSONDecodeError:
                structured_response = None
            # The report needs a JSON object; anything else is treated as unparseable.
            if not isinstance(structured_response, dict):
                return (
                    "Error: Could not parse structured response\n\n"
                    f"Raw response:\n{response_text}"
                )
            tool_outputs = self._run_tools(structured_response.get("tools_used"))
            return self._format_output(structured_response, tool_outputs)
        except Exception as e:
            logger.error(f"Analysis failed: {str(e)}")
            return f"Error during analysis: {str(e)}"

    def _build_prompt(self, query: str) -> str:
        """Combine system prompt, data context, tool list and the JSON reply schema."""
        data_info = self.get_data_info()
        # default=str keeps cells json cannot encode (e.g. timestamps) from aborting.
        sample_json = json.dumps(data_info['sample'], indent=2, default=str)
        # Serializing the template avoids hand-escaping JSON braces in an f-string.
        schema_json = json.dumps(self._RESPONSE_TEMPLATE, indent=2)
        return "\n".join([
            f"{self.system_prompt}",
            "Data Information:",
            f"- Columns: {data_info['columns']}",
            f"- Number of rows: {data_info['rows']}",
            f"- Sample data: {sample_json}",
            "Available Tools:",
            "1. describe_column(column: str) - Get statistical description of a column",
            "2. create_visualization(plot_type: str, x: str, y: str = None, title: str = None)",
            "   - Create visualizations (types: histogram, scatter, boxplot, bar)",
            "3. get_correlation(columns: List[str]) - Get correlation between columns",
            f"User Query: {query}",
            "Please provide a structured analysis in the following JSON format:",
            schema_json,
            "Important:",
            "- Be specific about which tools to use",
            "- Provide clear reasoning for each tool choice",
            "- Structure the output exactly as shown above",
            "- Reply with the JSON object only, without Markdown code fences",
            "",
        ])

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Remove a surrounding Markdown code fence (``` or ```json) if present."""
        cleaned = text.strip()
        if cleaned.startswith("```"):
            # Drop the opening fence line, which may carry a language tag.
            first_newline = cleaned.find("\n")
            cleaned = cleaned[first_newline + 1:] if first_newline != -1 else ""
            cleaned = cleaned.rstrip()
            if cleaned.endswith("```"):
                cleaned = cleaned[:-3]
        return cleaned.strip()

    def _run_tools(self, tool_calls) -> List[Dict[str, Any]]:
        """Execute the allow-listed tool calls requested by the model.

        Malformed entries and unknown tool names are skipped. A tool that raises
        is recorded as ``{"error": ...}`` instead of aborting the analysis.
        """
        outputs = []
        if not isinstance(tool_calls, list):
            return outputs
        for call in tool_calls:
            if not isinstance(call, dict):
                continue
            tool_name = call.get("tool")
            parameters = call.get("parameters") or {}
            if tool_name not in self._TOOL_NAMES or not isinstance(parameters, dict):
                continue
            try:
                result = getattr(self.tools, tool_name)(**parameters)
            except Exception as e:
                result = {"error": f"Tool {tool_name} failed: {str(e)}"}
            outputs.append({"tool": tool_name, "parameters": parameters, "result": result})
        return outputs

    @staticmethod
    def _bullets(items) -> str:
        """Render ``items`` as Markdown bullets, one per line."""
        if items is None:
            return ""
        if not isinstance(items, list):
            items = [items]
        return "\n".join(f"- {item}" for item in items)

    def _format_output(self, structured_response: Dict[str, Any], tool_outputs: List[Dict[str, Any]]) -> str:
        """Build the Markdown report from the parsed reply and the tool results."""
        sections = [
            "## Analysis Results",
            str(structured_response.get("answer", "")),
            "### Key Insights",
            self._bullets(structured_response.get("insights")),
            "### Visualizations",
            self._bullets(structured_response.get("visualizations")),
            "### Recommendations",
            self._bullets(structured_response.get("recommendations")),
            "### Limitations",
            self._bullets(structured_response.get("limitations")),
            "---",
            "Tool Outputs:",
        ]
        for out in tool_outputs:
            # default=str: tool results may hold values json cannot encode natively.
            result_json = json.dumps(out["result"], indent=2, default=str)
            sections.append(f"**{out['tool']}**:\n```json\n{result_json}\n```")
        # Blank lines between sections keep Markdown from merging them.
        return "\n\n".join(sections) + "\n"
def create_interface():
    """Create the Gradio interface.

    Builds a ``gr.Blocks`` app whose Analyze button passes the API key, system
    prompt, uploaded CSV and query to a ``DataAnalyzer`` and displays the
    resulting Markdown. The app is returned without being launched.
    """
    # One analyzer is created per interface, so every session of the app shares
    # it (configured key, loaded data, system prompt).
    analyzer = DataAnalyzer()

    default_system_prompt = (
        "You are an advanced data analysis expert. Analyze the provided data and answer the query.\n"
        "Focus on:\n"
        "1. Clear, structured analysis\n"
        "2. Statistical insights\n"
        "3. Appropriate visualizations\n"
        "4. Actionable recommendations"
    )

    def process_inputs(api_key: str, system_prompt: str, file, query: str):
        """Process user inputs and return analysis results as Markdown text."""
        if not api_key:
            return "Please enter your Gemini API key."
        if file is None:
            # Require the file on every request; falling back to the shared
            # analyzer's DataFrame would analyze data loaded by another request.
            return "Please upload a CSV file first"
        if not query or not query.strip():
            return "Please enter an analysis query."
        if api_key != analyzer.api_key:
            if not analyzer.configure_api(api_key):
                return "Failed to configure API. Please check your API key."
        analyzer.system_prompt = system_prompt
        success, message = analyzer.load_data(file)
        if not success:
            return message
        result = analyzer.analyze(query)
        # analyze() is annotated as returning a dict and may report errors as
        # {"error": ...}; gr.Markdown can only render text.
        if isinstance(result, dict):
            return str(result.get("error", result))
        return result

    # Create Gradio interface
    with gr.Blocks(title="Advanced Data Analysis Assistant") as interface:
        gr.Markdown("# Advanced Data Analysis Assistant")
        gr.Markdown("Upload your CSV file and get AI-powered analysis with visualizations")
        with gr.Row():
            api_key_input = gr.Textbox(
                label="Gemini API Key",
                placeholder="Enter your Gemini API key",
                type="password"
            )
        with gr.Row():
            system_prompt_input = gr.Textbox(
                label="System Prompt",
                placeholder="Enter system prompt for the AI",
                value=default_system_prompt,
                lines=4
            )
        with gr.Row():
            file_input = gr.File(
                label="Upload CSV",
                file_types=[".csv"]
            )
        with gr.Row():
            query_input = gr.Textbox(
                label="Analysis Query",
                placeholder="What would you like to know about the data?",
                lines=2
            )
        with gr.Row():
            submit_btn = gr.Button("Analyze")
        with gr.Row():
            output = gr.Markdown(label="Analysis Results")
        submit_btn.click(
            fn=process_inputs,
            inputs=[api_key_input, system_prompt_input, file_input, query_input],
            outputs=output
        )
    return interface
def main():
    """Build the Gradio app and serve it; log and re-raise any startup failure.

    ``share=True`` requests a Gradio share link, and ``server_name="0.0.0.0"``
    listens on all network interfaces on port 7860.
    """
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    try:
        create_interface().launch(**launch_options)
    except Exception as exc:
        logger.error(f"Application startup failed: {str(exc)}")
        raise
# Script entry point: launch the app only when run directly, not on import.
if __name__ == "__main__":
    main()