Update app.py
app.py CHANGED
@@ -5,101 +5,78 @@ import google.generativeai as genai
 import gradio as gr
 from typing import Dict, List, Any, Tuple
 import json
+import matplotlib.pyplot as plt
+import seaborn as sns
+import io
+import base64

 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

+class DataTools:
+    """Tools for data analysis that can be called by the AI"""
+    def __init__(self, df: pd.DataFrame):
+        self.df = df
+
+    def describe_column(self, column: str) -> dict:
+        """Get statistical description of a column"""
+        if column not in self.df.columns:
+            return {"error": f"Column {column} not found"}
+
+        stats = self.df[column].describe().to_dict()
+        null_count = self.df[column].isnull().sum()
+        return {
+            "statistics": stats,
+            "null_count": int(null_count),
+            "dtype": str(self.df[column].dtype)
+        }
+
+    def create_visualization(self, plot_type: str, x: str, y: str = None, title: str = None) -> str:
+        """Create a visualization and return as base64 string"""
+        try:
+            plt.figure(figsize=(10, 6))
+            if plot_type == "histogram":
+                sns.histplot(data=self.df, x=x)
+            elif plot_type == "scatter":
+                sns.scatterplot(data=self.df, x=x, y=y)
+            elif plot_type == "boxplot":
+                sns.boxplot(data=self.df, x=x, y=y)
+            elif plot_type == "bar":
+                sns.barplot(data=self.df, x=x, y=y)
+
+            if title:
+                plt.title(title)
+
+            # Save plot to bytes buffer
+            buf = io.BytesIO()
+            plt.savefig(buf, format='png')
+            buf.seek(0)
+            plt.close()
+
+            # Convert to base64
+            return base64.b64encode(buf.read()).decode('utf-8')
+        except Exception as e:
+            return f"Error creating visualization: {str(e)}"
+
+    def get_correlation(self, columns: List[str]) -> dict:
+        """Get correlation between specified columns"""
+        try:
+            corr = self.df[columns].corr().to_dict()
+            return {"correlation_matrix": corr}
+        except Exception as e:
+            return {"error": f"Error calculating correlation: {str(e)}"}
+
 class DataAnalyzer:
     def __init__(self):
         self.model = None
         self.api_key = None
         self.system_prompt = None
         self.df = None
+        self.tools = None

     def configure_api(self, api_key: str):
-
-            response = self.model.generate_content(prompt)
-            return response.text
-        except Exception as e:
-            logger.error(f"Analysis failed: {str(e)}")
-            return f"Analysis failed: {str(e)}"
-
-def create_interface():
-    analyzer = DataAnalyzer()
-
-    def process_inputs(api_key: str, system_prompt: str, file, query: str):
-        """Process user inputs and return analysis results"""
-        # Configure API
-        if api_key != analyzer.api_key:
-            if not analyzer.configure_api(api_key):
-                return "Failed to configure API. Please check your API key."
-
-        # Update system prompt
-        analyzer.system_prompt = system_prompt
-
-        # Load data if new file provided
-        if file is not None:
-            success, message = analyzer.load_data(file)
-            if not success:
-                return message
-
-        # Run analysis
-        return analyzer.analyze(query)
-
-    # Create Gradio interface
-    with gr.Blocks(title="Data Analysis Assistant") as interface:
-        gr.Markdown("# Data Analysis Assistant")
-        gr.Markdown("Upload your CSV file and get AI-powered analysis")
-
-        with gr.Row():
-            api_key_input = gr.Textbox(
-                label="Gemini API Key",
-                placeholder="Enter your Gemini API key",
-                type="password"
-            )
-
-        with gr.Row():
-            system_prompt_input = gr.Textbox(
-                label="System Prompt",
-                placeholder="Enter system prompt for the AI",
-                value="You are a data analysis expert. Analyze the provided data and answer the user's query.",
-                lines=3
-            )
-
-        with gr.Row():
-            file_input = gr.File(
-                label="Upload CSV",
-                file_types=[".csv"]
-            )
-
-        with gr.Row():
-            query_input = gr.Textbox(
-                label="Analysis Query",
-                placeholder="What would you like to know about the data?",
-                lines=2
-            )
-
-        with gr.Row():
-            submit_btn = gr.Button("Analyze")
-
-        with gr.Row():
-            output = gr.Markdown(label="Analysis Results")
-
-        submit_btn.click(
-            fn=process_inputs,
-            inputs=[api_key_input, system_prompt_input, file_input, query_input],
-            outputs=output
-        )
-
-    return interface
-
-def main():
-    interface = create_interface()
-    interface.launch()
-
-if __name__ == "__main__":
-    main()
+        """Configure the Gemini API with the provided key"""
         try:
             genai.configure(api_key=api_key)
             self.model = genai.GenerativeModel('gemini-1.5-pro')
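For reference, a rough usage sketch of the new DataTools helpers on their own (not part of the commit; the toy DataFrame, the column names, and the "from app import DataTools" import are illustrative assumptions, and the new matplotlib/seaborn dependencies must be installed):

    # Illustrative only: a toy DataFrame with made-up column names.
    import pandas as pd
    from app import DataTools  # assumes app.py is importable from the working directory

    df = pd.DataFrame({"age": [23, 31, 45, 27], "income": [40, 52, 80, 48]})
    tools = DataTools(df)

    print(tools.describe_column("age"))              # statistics, null_count, dtype
    print(tools.get_correlation(["age", "income"]))  # {"correlation_matrix": {...}}

    # create_visualization returns a base64-encoded PNG string on success
    png_b64 = tools.create_visualization("scatter", x="age", y="income", title="Age vs income")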
@@ -113,6 +90,7 @@ if __name__ == "__main__":
         """Load data from uploaded CSV file"""
         try:
             self.df = pd.read_csv(file.name)
+            self.tools = DataTools(self.df)
             return True, f"Loaded CSV with {len(self.df)} rows and {len(self.df.columns)} columns"
         except Exception as e:
             logger.error(f"Data loading failed: {str(e)}")
@@ -131,16 +109,16 @@ if __name__ == "__main__":
         }
         return info

-    def analyze(self, query: str) -> str:
-        """Analyze data based on user query"""
+    def analyze(self, query: str) -> Dict[str, Any]:
+        """Analyze data based on user query with structured output"""
         if self.model is None:
-            return "Please configure API key first"
+            return {"error": "Please configure API key first"}
         if self.df is None:
-            return "Please upload a CSV file first"
+            return {"error": "Please upload a CSV file first"}

         data_info = self.get_data_info()

-        # Combine system prompt with data context
+        # Combine system prompt with data context and tool instructions
         prompt = f"""{self.system_prompt}

 Data Information:
@@ -148,52 +126,89 @@ Data Information:
 - Number of rows: {data_info['rows']}
 - Sample data: {json.dumps(data_info['sample'], indent=2)}

-
+Available Tools:
+1. describe_column(column: str) - Get statistical description of a column
+2. create_visualization(plot_type: str, x: str, y: str = None, title: str = None)
+   - Create visualizations (types: histogram, scatter, boxplot, bar)
+3. get_correlation(columns: List[str]) - Get correlation between columns

-
-1. A clear explanation of your findings
-2. Key statistics relevant to the query
-3. If appropriate, suggest visualizations that would help understand the data better
+User Query: {query}

-
-
-
-
+Please provide a structured analysis in the following JSON format:
+{
+    "answer": "Direct answer to the query",
+    "tools_used": [
+        {
+            "tool": "tool_name",
+            "parameters": {"param1": "value1"},
+            "purpose": "Why this tool was used"
+        }
+    ],
+    "insights": ["List of key insights"],
+    "visualizations": ["List of suggested visualizations"],
+    "recommendations": ["List of recommendations"],
+    "limitations": ["Any limitations in the analysis"]
+}

-
--
--
--
-- Basic error checking
+Important:
+- Be specific about which tools to use
+- Provide clear reasoning for each tool choice
+- Structure the output exactly as shown above
 """
         try:
-            #
+            # Get initial response from Gemini
             response = self.model.generate_content(prompt)
+            response_text = response.text

-
-
-
-
-
-
-
-
-
-
-
-
+            try:
+                # Parse the response as JSON
+                structured_response = json.loads(response_text)
+
+                # Execute tool calls based on response
+                results = {"response": structured_response, "tool_outputs": []}
+
+                for tool_call in structured_response.get("tools_used", []):
+                    tool_name = tool_call["tool"]
+                    parameters = tool_call["parameters"]
+
+                    if hasattr(self.tools, tool_name):
+                        tool_method = getattr(self.tools, tool_name)
+                        tool_result = tool_method(**parameters)
+                        results["tool_outputs"].append({
+                            "tool": tool_name,
+                            "parameters": parameters,
+                            "result": tool_result
+                        })
+
+                # Format output for Gradio
+                formatted_output = f"""## Analysis Results
+
+{structured_response['answer']}
+
+### Key Insights
+{"".join(['- ' + insight + '\n' for insight in structured_response['insights']])}
+
+### Visualizations
+{"".join(['- ' + viz + '\n' for viz in structured_response['visualizations']])}
+
+### Recommendations
+{"".join(['- ' + rec + '\n' for rec in structured_response['recommendations']])}
+
+### Limitations
+{"".join(['- ' + lim + '\n' for lim in structured_response['limitations']])}
+
+---
+Tool Outputs:
+{"".join([f'\n**{out["tool"]}**:\n```json\n{json.dumps(out["result"], indent=2)}\n```' for out in results['tool_outputs']])}
+"""
+                return formatted_output
+
+            except json.JSONDecodeError:
+                return f"Error: Could not parse structured response\n\nRaw response:\n{response_text}"

         except Exception as e:
             logger.error(f"Analysis failed: {str(e)}")
-
-            "## Error During Analysis\n\n"
-            f"The analysis failed with error: {str(e)}\n\n"
-            "Please try:\n"
-            "1. Checking your API key\n"
-            "2. Simplifying your query\n"
-            "3. Ensuring your data is properly formatted"
-            )
-            return error_message
+            return f"Error during analysis: {str(e)}"

 def create_interface():
     """Create the Gradio interface"""
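The rewritten analyze() above asks Gemini for a JSON object, parses response.text with json.loads, and dispatches each entry in tools_used to a DataTools method via getattr. A small sketch of that dispatch step with a hand-written stand-in for the model's reply (values are made up; not part of the commit):

    # Hand-written stand-in for the model's JSON reply; values are illustrative.
    import json
    import pandas as pd
    from app import DataTools  # assumes app.py is importable

    tools = DataTools(pd.DataFrame({"age": [23, 31, 45], "income": [40, 52, 80]}))

    reply_text = json.dumps({
        "answer": "Income rises with age in this sample.",
        "tools_used": [
            {"tool": "get_correlation",
             "parameters": {"columns": ["age", "income"]},
             "purpose": "Quantify the relationship"}
        ],
        "insights": ["age and income move together"],
        "visualizations": ["scatter plot of age vs income"],
        "recommendations": ["check the relationship on the full dataset"],
        "limitations": ["only three rows in the sample"]
    })

    structured = json.loads(reply_text)  # analyze() applies this to response.text
    for call in structured.get("tools_used", []):
        if hasattr(tools, call["tool"]):
            result = getattr(tools, call["tool"])(**call["parameters"])
            print(call["tool"], "->", result)

Note that the committed code applies json.loads directly to response.text, so this path relies on the model returning bare JSON with no surrounding text.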
@@ -201,27 +216,23 @@ def create_interface():

     def process_inputs(api_key: str, system_prompt: str, file, query: str):
         """Process user inputs and return analysis results"""
-        # Configure API
         if api_key != analyzer.api_key:
             if not analyzer.configure_api(api_key):
                 return "Failed to configure API. Please check your API key."

-        # Update system prompt
         analyzer.system_prompt = system_prompt

-        # Load data if new file provided
         if file is not None:
             success, message = analyzer.load_data(file)
             if not success:
                 return message

-        # Run analysis
         return analyzer.analyze(query)

     # Create Gradio interface
-    with gr.Blocks(title="Data Analysis Assistant") as interface:
-        gr.Markdown("# Data Analysis Assistant")
-        gr.Markdown("Upload your CSV file and get AI-powered analysis")
+    with gr.Blocks(title="Advanced Data Analysis Assistant") as interface:
+        gr.Markdown("# Advanced Data Analysis Assistant")
+        gr.Markdown("Upload your CSV file and get AI-powered analysis with visualizations")

         with gr.Row():
             api_key_input = gr.Textbox(
@@ -234,8 +245,13 @@ def create_interface():
             system_prompt_input = gr.Textbox(
                 label="System Prompt",
                 placeholder="Enter system prompt for the AI",
-                value="You are a data analysis expert. Analyze the provided data and answer the user's query.",
-                lines=3
+                value="""You are an advanced data analysis expert. Analyze the provided data and answer the query.
+Focus on:
+1. Clear, structured analysis
+2. Statistical insights
+3. Appropriate visualizations
+4. Actionable recommendations""",
+                lines=4
             )

         with gr.Row():
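Finally, a minimal way to exercise the updated DataAnalyzer outside the Gradio UI (a sketch, not part of the commit: the API key and CSV path are placeholders, and configure_api is assumed to return a truthy value on success, as process_inputs relies on):

    from types import SimpleNamespace
    from app import DataAnalyzer  # assumes app.py is importable

    analyzer = DataAnalyzer()
    analyzer.system_prompt = "You are a data analysis expert."

    if analyzer.configure_api("YOUR_GEMINI_API_KEY"):  # placeholder key
        # load_data() reads file.name, so any object with a .name attribute works
        ok, msg = analyzer.load_data(SimpleNamespace(name="data.csv"))  # placeholder path
        print(msg)
        if ok:
            print(analyzer.analyze("Which columns are most strongly correlated?"))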