jzou19950715 committed on
Commit
e5bb249
·
verified ·
1 Parent(s): 4ad3262

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -69
app.py CHANGED
@@ -1,104 +1,163 @@
 
 
1
  import os
2
  from dataclasses import dataclass
3
  from typing import Any, Callable, Dict, List, Optional
4
 
5
  import gradio as gr
 
 
6
  import pandas as pd
7
- import torch
8
  from litellm import completion
9
 
10
- # Agent Implementation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
@dataclass
class Tool:
    """Record describing one callable tool: a name, a description, a function."""

    # Identifier used when the tool is listed.
    name: str
    # Free-text summary of what the tool does.
    description: str
    # The callable that implements the tool.
    func: Callable
17
 
18
class MinimalAgent:
    """Single-request demo agent.

    Builds a system prompt that lists the registered tools, sends it together
    with the user's prompt to the model via ``completion`` and hands back the
    text of the first choice.
    """

    def __init__(
        self,
        model_id: str = "gpt-4o-mini",
        temperature: float = 0.7,
        max_steps: int = 5
    ):
        self.tools: List[Tool] = []
        self.model_id = model_id
        self.temperature = temperature
        # Stored only; nothing in this class reads it.
        self.max_steps = max_steps

    def add_tool(self, name: str, description: str, func: Callable) -> None:
        """Register a tool so that it appears in the system prompt."""
        tool = Tool(name=name, description=description, func=func)
        self.tools.append(tool)

    def run(self, prompt: str, **kwargs) -> str:
        """Send ``prompt`` to the model and return the reply text.

        Extra keyword arguments are accepted and ignored. Any exception raised
        by the request or while reading the response is returned as an
        ``"Error: ..."`` string rather than propagated.
        """
        system_message = {"role": "system", "content": self._get_system_prompt()}
        user_message = {"role": "user", "content": prompt}

        try:
            reply = completion(
                model=self.model_id,
                messages=[system_message, user_message],
                temperature=self.temperature,
            )
            return reply.choices[0].message.content
        except Exception as exc:
            return f"Error: {exc!s}"

    def _get_system_prompt(self) -> str:
        """Assemble the system prompt, one ``- name: description`` line per tool."""
        tool_lines = [f"- {tool.name}: {tool.description}" for tool in self.tools]

        prompt_lines = [
            "You are a helpful AI agent that can analyze data and write code.",
            "",
            "Available tools:",
            "\n".join(tool_lines),
            "",
            "Additional capabilities:",
            "- Data analysis with pandas, numpy",
            "- Visualization with matplotlib, seaborn",
            "- Machine learning with sklearn",
            "- Statistical analysis with scipy",
            "",
            "Provide clear explanations and code examples.",
        ]
        return "\n".join(prompt_lines)
72
 
73
- # Analysis Functions
74
def analyze_dataframe(df: pd.DataFrame, analysis_type: str) -> str:
    """Return a plain-text report about ``df``.

    Args:
        df: DataFrame to inspect.
        analysis_type: ``"summary"`` for the ``df.describe()`` table or
            ``"info"`` for the ``df.info()`` listing.

    Returns:
        The report text, or ``"Unknown analysis type"`` for any other
        ``analysis_type``.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import list.
    import io

    if analysis_type == "summary":
        return str(df.describe())

    if analysis_type == "info":
        # DataFrame.info() writes into a file-like object. A list has no
        # write() method, so a text buffer is required here.
        buffer = io.StringIO()
        df.info(buf=buffer)
        return buffer.getvalue()

    return "Unknown analysis type"
83
-
84
def plot_data(df: pd.DataFrame, plot_type: str) -> None:
    """Draw a plot of ``df`` using matplotlib's global pyplot state.

    Args:
        df: DataFrame to plot.
        plot_type: ``"correlation"`` draws an annotated heatmap of the
            correlation matrix; ``"distribution"`` draws one histogram per
            column via ``df.hist``. Any other value draws nothing.

    Returns:
        None. The figure is left open on pyplot for the caller to show or save.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    if plot_type == "correlation":
        # Correlation is only defined for numeric data. Recent pandas raises
        # on non-numeric columns instead of dropping them, so select the
        # usable columns explicitly.
        numeric_df = df.select_dtypes(include=["number", "bool"])
        plt.figure(figsize=(10, 8))
        sns.heatmap(numeric_df.corr(), annot=True)
        plt.title("Correlation Heatmap")
    elif plot_type == "distribution":
        df.hist(figsize=(15, 10))
        plt.tight_layout()
96
 
97
  def process_file(file: gr.File) -> Optional[pd.DataFrame]:
98
  """Process uploaded file into DataFrame"""
99
  if not file:
100
  return None
101
-
102
  try:
103
  if file.name.endswith('.csv'):
104
  return pd.read_csv(file.name)
@@ -118,32 +177,20 @@ def analyze_data(
118
 
119
  if not api_key:
120
  return "Error: Please provide an API key."
121
-
122
  if not file:
123
  return "Error: Please upload a file."
124
-
125
  try:
126
  # Set up environment
127
  os.environ["OPENAI_API_KEY"] = api_key
128
 
129
  # Create agent
130
- agent = MinimalAgent(
131
  model_id="gpt-4o-mini",
132
  temperature=temperature
133
  )
134
 
135
- # Add tools
136
- agent.add_tool(
137
- "analyze_dataframe",
138
- "Analyze DataFrame with various metrics",
139
- analyze_dataframe
140
- )
141
- agent.add_tool(
142
- "plot_data",
143
- "Create various plots from DataFrame",
144
- plot_data
145
- )
146
-
147
  # Process file
148
  df = process_file(file)
149
  if df is None:
@@ -168,13 +215,12 @@ def analyze_data(
168
  User request: {query}
169
 
170
  Please analyze the data and provide:
171
- 1. A clear explanation of your approach
172
- 2. Code for the analysis
173
- 3. Visualizations where relevant
174
- 4. Key insights and findings
175
  """
176
 
177
- return agent.run(prompt)
178
 
179
  except Exception as e:
180
  return f"Error occurred: {str(e)}"
@@ -186,14 +232,14 @@ def create_interface():
186
  gr.Markdown("""
187
  # AI Data Analysis Assistant
188
 
189
- Upload your data file and ask questions in natural language.
190
 
191
  **Features:**
192
  - Data analysis and visualization
193
  - Statistical analysis
194
  - Machine learning capabilities
195
 
196
- **Note**: Requires your own GPT-4 API key.
197
  """)
198
 
199
  with gr.Row():
@@ -232,10 +278,10 @@ def create_interface():
232
 
233
  gr.Examples(
234
  examples=[
235
- [None, "Show key statistics and create visualizations for numeric columns"],
236
- [None, "Find correlations and patterns in the data"],
237
- [None, "Identify outliers and unusual patterns"],
238
- [None, "Create summary visualizations of the main variables"],
239
  ],
240
  inputs=[file, query]
241
  )
 
1
+ import base64
2
+ import io
3
  import os
4
  from dataclasses import dataclass
5
  from typing import Any, Callable, Dict, List, Optional
6
 
7
  import gradio as gr
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
  import pandas as pd
11
+ import seaborn as sns
12
  from litellm import completion
13
 
14
+
15
+ # Code Execution Environment
16
+ class CodeEnvironment:
17
+ """Safe environment for executing code with data analysis capabilities"""
18
+
19
+ def __init__(self):
20
+ self.globals = {
21
+ 'pd': pd,
22
+ 'np': np,
23
+ 'plt': plt,
24
+ 'sns': sns,
25
+ }
26
+ self.locals = {}
27
+
28
+ def execute(self, code: str, df: pd.DataFrame = None) -> Dict[str, Any]:
29
+ """Execute code and capture outputs"""
30
+ if df is not None:
31
+ self.globals['df'] = df
32
+
33
+ # Capture output
34
+ output_buffer = io.StringIO()
35
+ result = {'output': '', 'figures': [], 'error': None}
36
+
37
+ try:
38
+ # Execute code
39
+ exec(code, self.globals, self.locals)
40
+
41
+ # Capture figures
42
+ for i in plt.get_fignums():
43
+ fig = plt.figure(i)
44
+ buf = io.BytesIO()
45
+ fig.savefig(buf, format='png')
46
+ buf.seek(0)
47
+ img_str = base64.b64encode(buf.read()).decode()
48
+ result['figures'].append(f"data:image/png;base64,{img_str}")
49
+ plt.close(fig)
50
+
51
+ # Get printed output
52
+ result['output'] = output_buffer.getvalue()
53
+
54
+ except Exception as e:
55
+ result['error'] = str(e)
56
+
57
+ finally:
58
+ output_buffer.close()
59
+
60
+ return result
61
+
62
@dataclass
class Tool:
    """Data-analysis tool record: a name, a description and its callable."""

    # Identifier used when the tool is listed.
    name: str
    # Free-text summary of what the tool does.
    description: str
    # The callable that implements the tool.
    func: Callable
68
 
69
class AnalysisAgent:
    """Agent that asks the model for an analysis and runs the code it returns.

    One request is sent per ``run`` call. Python code blocks found in the reply
    are executed in ``self.code_env`` and their printed output and figures are
    appended to the reply text. Registered tools are only listed in the system
    prompt; this class never calls them itself.
    """

    def __init__(
        self,
        model_id: str = "gpt-4o-mini",
        temperature: float = 0.7,
    ):
        self.model_id = model_id
        self.temperature = temperature
        self.tools: List[Tool] = []
        self.code_env = CodeEnvironment()

    def add_tool(self, name: str, description: str, func: Callable) -> None:
        """Add a tool to the agent"""
        self.tools.append(Tool(name=name, description=description, func=func))

    def run(self, prompt: str, df: Optional[pd.DataFrame] = None) -> str:
        """Run analysis with code execution.

        Args:
            prompt: User request sent to the model.
            df: Optional DataFrame made available to the executed code.

        Returns:
            The model's reply, followed by the output and figures of each
            executed code block (figures as markdown images). Any exception is
            returned as an ``"Error: ..."`` string instead of being raised.
        """
        messages = [
            {"role": "system", "content": self._get_system_prompt()},
            {"role": "user", "content": prompt}
        ]

        try:
            # Get response from model
            response = completion(
                model=self.model_id,
                messages=messages,
                temperature=self.temperature,
            )
            # The content field can be null in a chat completion; treat that
            # as an empty reply instead of failing on string concatenation.
            analysis = response.choices[0].message.content or ""

            # Execute each code block and collect what it produced.
            results = []
            for code in self._extract_code(analysis):
                result = self.code_env.execute(code, df)
                if result['error']:
                    results.append(f"Error executing code: {result['error']}")
                    continue
                if result['output']:
                    results.append(result['output'])
                results.extend(f"![Figure]({fig})" for fig in result['figures'])

            # Nothing executed or produced: return the reply without a
            # dangling separator.
            if not results:
                return analysis
            return analysis + "\n\n" + "\n".join(results)

        except Exception as e:
            return f"Error: {str(e)}"

    def _get_system_prompt(self) -> str:
        """Get system prompt with tools and capabilities"""
        tools_desc = "\n".join([
            f"- {tool.name}: {tool.description}"
            for tool in self.tools
        ])

        # Continuation lines are flush-left so no indentation leaks into the
        # prompt text.
        return f"""You are a data analysis assistant.

Available tools:
{tools_desc}

Capabilities:
- Data analysis (pandas, numpy)
- Visualization (matplotlib, seaborn)
- Statistical analysis (scipy)
- Machine learning (sklearn)

When writing code:
- Use markdown code blocks
- Create clear visualizations
- Include explanations
- Handle errors gracefully
"""

    @staticmethod
    def _extract_code(text: str) -> List[str]:
        """Extract Python code blocks from markdown.

        Accepts fences tagged ``python`` or ``py`` in any letter case, with
        optional trailing spaces and either LF or CRLF line endings. Fences
        with another tag, or with no tag, are ignored.

        Returns:
            The body of each matching block, in order of appearance.
        """
        import re

        if not text:
            return []
        pattern = r'```(?:python|py)[ \t]*\r?\n(.*?)```'
        return re.findall(pattern, text, re.DOTALL | re.IGNORECASE)
 
155
 
156
  def process_file(file: gr.File) -> Optional[pd.DataFrame]:
157
  """Process uploaded file into DataFrame"""
158
  if not file:
159
  return None
160
+
161
  try:
162
  if file.name.endswith('.csv'):
163
  return pd.read_csv(file.name)
 
177
 
178
  if not api_key:
179
  return "Error: Please provide an API key."
180
+
181
  if not file:
182
  return "Error: Please upload a file."
183
+
184
  try:
185
  # Set up environment
186
  os.environ["OPENAI_API_KEY"] = api_key
187
 
188
  # Create agent
189
+ agent = AnalysisAgent(
190
  model_id="gpt-4o-mini",
191
  temperature=temperature
192
  )
193
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  # Process file
195
  df = process_file(file)
196
  if df is None:
 
215
  User request: {query}
216
 
217
  Please analyze the data and provide:
218
+ 1. Clear explanation of approach
219
+ 2. Code with visualizations
220
+ 3. Key insights and findings
 
221
  """
222
 
223
+ return agent.run(prompt, df=df)
224
 
225
  except Exception as e:
226
  return f"Error occurred: {str(e)}"
 
232
  gr.Markdown("""
233
  # AI Data Analysis Assistant
234
 
235
+ Upload your data file and get AI-powered analysis with visualizations.
236
 
237
  **Features:**
238
  - Data analysis and visualization
239
  - Statistical analysis
240
  - Machine learning capabilities
241
 
242
+ **Note**: Requires your own OpenAi API key.
243
  """)
244
 
245
  with gr.Row():
 
278
 
279
  gr.Examples(
280
  examples=[
281
+ [None, "Show the distribution of values and key statistics"],
282
+ [None, "Create a correlation analysis with heatmap"],
283
+ [None, "Identify and visualize any outliers in the data"],
284
+ [None, "Generate summary plots for the main variables"],
285
  ],
286
  inputs=[file, query]
287
  )