jzou19950715 committed on
Commit
4dec0f2
·
verified ·
1 Parent(s): dfc517e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +148 -132
app.py CHANGED
@@ -5,101 +5,78 @@ import google.generativeai as genai
5
  import gradio as gr
6
  from typing import Dict, List, Any, Tuple
7
  import json
 
 
 
 
8
 
9
  # Configure logging
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  class DataAnalyzer:
14
  def __init__(self):
15
  self.model = None
16
  self.api_key = None
17
  self.system_prompt = None
18
  self.df = None
 
19
 
20
  def configure_api(self, api_key: str):
21
- try:
22
- response = self.model.generate_content(prompt)
23
- return response.text
24
- except Exception as e:
25
- logger.error(f"Analysis failed: {str(e)}")
26
- return f"Analysis failed: {str(e)}"
27
-
28
- def create_interface():
29
- analyzer = DataAnalyzer()
30
-
31
- def process_inputs(api_key: str, system_prompt: str, file, query: str):
32
- """Process user inputs and return analysis results"""
33
- # Configure API
34
- if api_key != analyzer.api_key:
35
- if not analyzer.configure_api(api_key):
36
- return "Failed to configure API. Please check your API key."
37
-
38
- # Update system prompt
39
- analyzer.system_prompt = system_prompt
40
-
41
- # Load data if new file provided
42
- if file is not None:
43
- success, message = analyzer.load_data(file)
44
- if not success:
45
- return message
46
-
47
- # Run analysis
48
- return analyzer.analyze(query)
49
-
50
- # Create Gradio interface
51
- with gr.Blocks(title="Data Analysis Assistant") as interface:
52
- gr.Markdown("# Data Analysis Assistant")
53
- gr.Markdown("Upload your CSV file and get AI-powered analysis")
54
-
55
- with gr.Row():
56
- api_key_input = gr.Textbox(
57
- label="Gemini API Key",
58
- placeholder="Enter your Gemini API key",
59
- type="password"
60
- )
61
-
62
- with gr.Row():
63
- system_prompt_input = gr.Textbox(
64
- label="System Prompt",
65
- placeholder="Enter system prompt for the AI",
66
- value="You are a data analysis expert. Analyze the provided data and answer the user's query.",
67
- lines=3
68
- )
69
-
70
- with gr.Row():
71
- file_input = gr.File(
72
- label="Upload CSV",
73
- file_types=[".csv"]
74
- )
75
-
76
- with gr.Row():
77
- query_input = gr.Textbox(
78
- label="Analysis Query",
79
- placeholder="What would you like to know about the data?",
80
- lines=2
81
- )
82
-
83
- with gr.Row():
84
- submit_btn = gr.Button("Analyze")
85
-
86
- with gr.Row():
87
- output = gr.Markdown(label="Analysis Results")
88
-
89
- submit_btn.click(
90
- fn=process_inputs,
91
- inputs=[api_key_input, system_prompt_input, file_input, query_input],
92
- outputs=output
93
- )
94
-
95
- return interface
96
-
97
- def main():
98
- interface = create_interface()
99
- interface.launch()
100
-
101
- if __name__ == "__main__":
102
- main()Configure the Gemini API with the provided key"""
103
  try:
104
  genai.configure(api_key=api_key)
105
  self.model = genai.GenerativeModel('gemini-1.5-pro')
@@ -113,6 +90,7 @@ if __name__ == "__main__":
113
  """Load data from uploaded CSV file"""
114
  try:
115
  self.df = pd.read_csv(file.name)
 
116
  return True, f"Loaded CSV with {len(self.df)} rows and {len(self.df.columns)} columns"
117
  except Exception as e:
118
  logger.error(f"Data loading failed: {str(e)}")
@@ -131,16 +109,16 @@ if __name__ == "__main__":
131
  }
132
  return info
133
 
134
- def analyze(self, query: str) -> str:
135
- """Analyze data based on user query"""
136
  if self.model is None:
137
- return "Please configure API key first"
138
  if self.df is None:
139
- return "Please upload a CSV file first"
140
 
141
  data_info = self.get_data_info()
142
 
143
- # Combine system prompt with data context
144
  prompt = f"""{self.system_prompt}
145
 
146
  Data Information:
@@ -148,52 +126,89 @@ Data Information:
148
  - Number of rows: {data_info['rows']}
149
  - Sample data: {json.dumps(data_info['sample'], indent=2)}
150
 
151
- User Query: {query}
 
 
 
 
152
 
153
- Please analyze this data and provide:
154
- 1. A clear explanation of your findings
155
- 2. Key statistics relevant to the query
156
- 3. If appropriate, suggest visualizations that would help understand the data better
157
 
158
- Response Format:
159
- 1. First give a direct answer to the query
160
- 2. Then provide supporting statistics
161
- 3. Finally, suggest any relevant additional insights
 
 
 
 
 
 
 
 
 
 
 
162
 
163
- Remember to handle:
164
- - Missing or null values
165
- - Outliers
166
- - Data type conversions if needed
167
- - Basic error checking
168
  """
169
  try:
170
- # Call Gemini API
171
  response = self.model.generate_content(prompt)
 
172
 
173
- # Extract and format the response
174
- if response.text:
175
- formatted_response = (
176
- "## Analysis Results\n\n"
177
- f"{response.text}\n\n"
178
- "---\n"
179
- "Note: This analysis was generated using the provided data. "
180
- "Please verify any critical insights independently."
181
- )
182
- return formatted_response
183
- else:
184
- return "No analysis could be generated. Please try a different query."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
  except Exception as e:
187
  logger.error(f"Analysis failed: {str(e)}")
188
- error_message = (
189
- "## Error During Analysis\n\n"
190
- f"The analysis failed with error: {str(e)}\n\n"
191
- "Please try:\n"
192
- "1. Checking your API key\n"
193
- "2. Simplifying your query\n"
194
- "3. Ensuring your data is properly formatted"
195
- )
196
- return error_message
197
 
198
  def create_interface():
199
  """Create the Gradio interface"""
@@ -201,27 +216,23 @@ def create_interface():
201
 
202
  def process_inputs(api_key: str, system_prompt: str, file, query: str):
203
  """Process user inputs and return analysis results"""
204
- # Configure API
205
  if api_key != analyzer.api_key:
206
  if not analyzer.configure_api(api_key):
207
  return "Failed to configure API. Please check your API key."
208
 
209
- # Update system prompt
210
  analyzer.system_prompt = system_prompt
211
 
212
- # Load data if new file provided
213
  if file is not None:
214
  success, message = analyzer.load_data(file)
215
  if not success:
216
  return message
217
 
218
- # Run analysis
219
  return analyzer.analyze(query)
220
 
221
  # Create Gradio interface
222
- with gr.Blocks(title="Data Analysis Assistant") as interface:
223
- gr.Markdown("# Data Analysis Assistant")
224
- gr.Markdown("Upload your CSV file and get AI-powered analysis")
225
 
226
  with gr.Row():
227
  api_key_input = gr.Textbox(
@@ -234,8 +245,13 @@ def create_interface():
234
  system_prompt_input = gr.Textbox(
235
  label="System Prompt",
236
  placeholder="Enter system prompt for the AI",
237
- value="You are a data analysis expert. Analyze the provided data and answer the user's query.",
238
- lines=3
 
 
 
 
 
239
  )
240
 
241
  with gr.Row():
 
5
  import gradio as gr
6
  from typing import Dict, List, Any, Tuple
7
  import json
8
+ import matplotlib.pyplot as plt
9
+ import seaborn as sns
10
+ import io
11
+ import base64
12
 
13
  # Configure logging
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
16
 
17
class DataTools:
    """Tools for data analysis that can be called by the AI.

    Each public method operates on the DataFrame supplied at construction
    time and returns a plain value (dict or str) instead of raising, so the
    caller can embed the result directly in its output.
    """

    # Plot types understood by create_visualization.
    SUPPORTED_PLOT_TYPES = ("histogram", "scatter", "boxplot", "bar")
    # Plot types that cannot be drawn without a y column.
    _PLOT_TYPES_REQUIRING_Y = ("scatter", "bar")

    def __init__(self, df: pd.DataFrame):
        """Store the DataFrame that every tool method reads from."""
        self.df = df

    @staticmethod
    def _to_native(value):
        """Return the built-in Python equivalent of a NumPy scalar.

        Values without an ``item()`` method (str, int, float, None, ...) are
        returned unchanged.
        """
        item = getattr(value, "item", None)
        return item() if callable(item) else value

    def describe_column(self, column: str) -> dict:
        """Get statistical description of a column.

        Args:
            column: Name of the column to describe.

        Returns:
            ``{"statistics", "null_count", "dtype"}`` on success, or
            ``{"error": ...}`` when the column does not exist.
        """
        if column not in self.df.columns:
            return {"error": f"Column {column} not found"}

        series = self.df[column]
        # Normalise to built-in scalars so the result can be passed to
        # json.dumps even if describe() hands back NumPy scalar objects.
        stats = {
            key: self._to_native(value)
            for key, value in series.describe().to_dict().items()
        }
        return {
            "statistics": stats,
            "null_count": int(series.isnull().sum()),
            "dtype": str(series.dtype),
        }

    def create_visualization(self, plot_type: str, x: str, y: str = None, title: str = None) -> str:
        """Create a visualization and return it as a base64 string.

        Args:
            plot_type: One of ``SUPPORTED_PLOT_TYPES``.
            x: Column plotted on the x axis.
            y: Column plotted on the y axis; required for scatter and bar.
            title: Optional figure title.

        Returns:
            The PNG image encoded as base64 text, or a message starting with
            ``"Error creating visualization: "`` when the plot cannot be made.
        """
        fig = None
        try:
            # Reject bad requests before any figure is allocated; previously
            # an unknown plot type produced a blank image.
            if plot_type not in self.SUPPORTED_PLOT_TYPES:
                raise ValueError(
                    f"Unsupported plot type '{plot_type}'. "
                    f"Expected one of: {', '.join(self.SUPPORTED_PLOT_TYPES)}"
                )
            if y is None and plot_type in self._PLOT_TYPES_REQUIRING_Y:
                raise ValueError(f"Plot type '{plot_type}' requires a y column")

            fig = plt.figure(figsize=(10, 6))
            if plot_type == "histogram":
                sns.histplot(data=self.df, x=x)
            elif plot_type == "scatter":
                sns.scatterplot(data=self.df, x=x, y=y)
            elif plot_type == "boxplot":
                sns.boxplot(data=self.df, x=x, y=y)
            else:  # "bar"
                sns.barplot(data=self.df, x=x, y=y)

            if title:
                plt.title(title)

            # Render into memory and encode the PNG bytes as base64 text.
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
            return base64.b64encode(buf.getvalue()).decode('utf-8')
        except Exception as e:
            return f"Error creating visualization: {str(e)}"
        finally:
            # Close the figure on every path so failed plots do not
            # accumulate open figures in pyplot.
            if fig is not None:
                plt.close(fig)

    def get_correlation(self, columns: List[str]) -> dict:
        """Get correlation between specified columns.

        Args:
            columns: Names of the columns to correlate.

        Returns:
            ``{"correlation_matrix": ...}`` as a nested dict, or
            ``{"error": ...}`` if the columns cannot be selected or correlated.
        """
        try:
            corr = self.df[columns].corr().to_dict()
            return {"correlation_matrix": corr}
        except Exception as e:
            return {"error": f"Error calculating correlation: {str(e)}"}
69
+
70
  class DataAnalyzer:
71
  def __init__(self):
72
  self.model = None
73
  self.api_key = None
74
  self.system_prompt = None
75
  self.df = None
76
+ self.tools = None
77
 
78
  def configure_api(self, api_key: str):
79
+ """Configure the Gemini API with the provided key"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  try:
81
  genai.configure(api_key=api_key)
82
  self.model = genai.GenerativeModel('gemini-1.5-pro')
 
90
  """Load data from uploaded CSV file"""
91
  try:
92
  self.df = pd.read_csv(file.name)
93
+ self.tools = DataTools(self.df)
94
  return True, f"Loaded CSV with {len(self.df)} rows and {len(self.df.columns)} columns"
95
  except Exception as e:
96
  logger.error(f"Data loading failed: {str(e)}")
 
109
  }
110
  return info
111
 
112
+ def analyze(self, query: str) -> Dict[str, Any]:
113
+ """Analyze data based on user query with structured output"""
114
  if self.model is None:
115
+ return {"error": "Please configure API key first"}
116
  if self.df is None:
117
+ return {"error": "Please upload a CSV file first"}
118
 
119
  data_info = self.get_data_info()
120
 
121
+ # Combine system prompt with data context and tool instructions
122
  prompt = f"""{self.system_prompt}
123
 
124
  Data Information:
 
126
  - Number of rows: {data_info['rows']}
127
  - Sample data: {json.dumps(data_info['sample'], indent=2)}
128
 
129
+ Available Tools:
130
+ 1. describe_column(column: str) - Get statistical description of a column
131
+ 2. create_visualization(plot_type: str, x: str, y: str = None, title: str = None)
132
+ - Create visualizations (types: histogram, scatter, boxplot, bar)
133
+ 3. get_correlation(columns: List[str]) - Get correlation between columns
134
 
135
+ User Query: {query}
 
 
 
136
 
137
+ Please provide a structured analysis in the following JSON format:
138
+ {
139
+ "answer": "Direct answer to the query",
140
+ "tools_used": [
141
+ {
142
+ "tool": "tool_name",
143
+ "parameters": {"param1": "value1"},
144
+ "purpose": "Why this tool was used"
145
+ }
146
+ ],
147
+ "insights": ["List of key insights"],
148
+ "visualizations": ["List of suggested visualizations"],
149
+ "recommendations": ["List of recommendations"],
150
+ "limitations": ["Any limitations in the analysis"]
151
+ }
152
 
153
+ Important:
154
+ - Be specific about which tools to use
155
+ - Provide clear reasoning for each tool choice
156
+ - Structure the output exactly as shown above
 
157
  """
158
  try:
159
+ # Get initial response from Gemini
160
  response = self.model.generate_content(prompt)
161
+ response_text = response.text
162
 
163
+ try:
164
+ # Parse the response as JSON
165
+ structured_response = json.loads(response_text)
166
+
167
+ # Execute tool calls based on response
168
+ results = {"response": structured_response, "tool_outputs": []}
169
+
170
+ for tool_call in structured_response.get("tools_used", []):
171
+ tool_name = tool_call["tool"]
172
+ parameters = tool_call["parameters"]
173
+
174
+ if hasattr(self.tools, tool_name):
175
+ tool_method = getattr(self.tools, tool_name)
176
+ tool_result = tool_method(**parameters)
177
+ results["tool_outputs"].append({
178
+ "tool": tool_name,
179
+ "parameters": parameters,
180
+ "result": tool_result
181
+ })
182
+
183
+ # Format output for Gradio
184
+ formatted_output = f"""## Analysis Results
185
+
186
+ {structured_response['answer']}
187
+
188
+ ### Key Insights
189
+ {"".join(['- ' + insight + '\\n' for insight in structured_response['insights']])}
190
+
191
+ ### Visualizations
192
+ {"".join(['- ' + viz + '\\n' for viz in structured_response['visualizations']])}
193
+
194
+ ### Recommendations
195
+ {"".join(['- ' + rec + '\\n' for rec in structured_response['recommendations']])}
196
+
197
+ ### Limitations
198
+ {"".join(['- ' + lim + '\\n' for lim in structured_response['limitations']])}
199
+
200
+ ---
201
+ Tool Outputs:
202
+ {"".join([f'\\n**{out["tool"]}**:\\n```json\\n{json.dumps(out["result"], indent=2)}\\n```' for out in results['tool_outputs']])}
203
+ """
204
+ return formatted_output
205
+
206
+ except json.JSONDecodeError:
207
+ return f"Error: Could not parse structured response\\n\\nRaw response:\\n{response_text}"
208
 
209
  except Exception as e:
210
  logger.error(f"Analysis failed: {str(e)}")
211
+ return f"Error during analysis: {str(e)}"
 
 
 
 
 
 
 
 
212
 
213
  def create_interface():
214
  """Create the Gradio interface"""
 
216
 
217
  def process_inputs(api_key: str, system_prompt: str, file, query: str):
218
  """Process user inputs and return analysis results"""
 
219
  if api_key != analyzer.api_key:
220
  if not analyzer.configure_api(api_key):
221
  return "Failed to configure API. Please check your API key."
222
 
 
223
  analyzer.system_prompt = system_prompt
224
 
 
225
  if file is not None:
226
  success, message = analyzer.load_data(file)
227
  if not success:
228
  return message
229
 
 
230
  return analyzer.analyze(query)
231
 
232
  # Create Gradio interface
233
+ with gr.Blocks(title="Advanced Data Analysis Assistant") as interface:
234
+ gr.Markdown("# Advanced Data Analysis Assistant")
235
+ gr.Markdown("Upload your CSV file and get AI-powered analysis with visualizations")
236
 
237
  with gr.Row():
238
  api_key_input = gr.Textbox(
 
245
  system_prompt_input = gr.Textbox(
246
  label="System Prompt",
247
  placeholder="Enter system prompt for the AI",
248
+ value="""You are an advanced data analysis expert. Analyze the provided data and answer the query.
249
+ Focus on:
250
+ 1. Clear, structured analysis
251
+ 2. Statistical insights
252
+ 3. Appropriate visualizations
253
+ 4. Actionable recommendations""",
254
+ lines=4
255
  )
256
 
257
  with gr.Row():