jzou19950715 committed on
Commit
a5d47f2
·
verified ·
1 Parent(s): c94ceb7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -196
app.py CHANGED
@@ -8,254 +8,168 @@ import tempfile
8
  import sys
9
  from io import StringIO
10
  import matplotlib.pyplot as plt
11
- import re
12
  from pathlib import Path
13
  import importlib
 
14
 
15
- class CodeExecutionEnvironment:
 
16
  def __init__(self):
17
  self.globals_dict = {}
18
- self.figures_dir = "temp_figures"
19
- os.makedirs(self.figures_dir, exist_ok=True)
 
20
 
21
- def install_package(self, package_name):
22
- """Dynamically install a Python package"""
 
 
 
 
 
 
 
 
 
23
  try:
24
- subprocess.check_call([sys.executable, "-m", "pip", "install", package_name],
25
- stdout=subprocess.DEVNULL,
26
- stderr=subprocess.DEVNULL)
 
 
 
27
  return True
28
- except:
 
29
  return False
30
 
31
- def execute_code(self, code):
32
- """Execute code in the managed environment"""
33
- # Redirect stdout to capture prints
34
- old_stdout = sys.stdout
35
- redirected_output = StringIO()
36
- sys.stdout = redirected_output
37
-
38
  try:
39
- # First pass: collect and install required imports
40
- import_lines = [line for line in code.split('\n') if 'import' in line]
41
- for line in import_lines:
42
- parts = line.split()
43
- if parts[0] == 'import':
44
- package = parts[1].split('.')[0]
45
- if package not in sys.modules:
46
- self.install_package(package)
47
- try:
48
- self.globals_dict[package] = importlib.import_module(package)
49
- except:
50
- pass
51
- elif parts[0] == 'from':
52
- package = parts[1].split('.')[0]
53
- if package not in sys.modules:
54
- self.install_package(package)
55
-
56
- # Add common data science packages to globals
57
- if 'pd' not in self.globals_dict:
58
- self.globals_dict['pd'] = pd
59
- if 'plt' not in self.globals_dict:
60
- self.globals_dict['plt'] = plt
61
- if 'np' not in self.globals_dict:
62
- import numpy as np
63
- self.globals_dict['np'] = np
64
 
 
 
 
 
 
 
 
 
 
 
65
  # Execute the code
66
  exec(code, self.globals_dict)
67
- output = redirected_output.getvalue()
68
-
69
- # Handle different types of figures
70
- figures = []
71
 
72
- # Handle Matplotlib figures
73
- if plt.get_figs():
 
 
 
 
74
  for i, fig in enumerate(plt.get_figs()):
75
- fig_path = os.path.join(self.figures_dir, f"mpl_figure_{len(figures)}.png")
76
  fig.savefig(fig_path)
77
  figures.append(fig_path)
78
  plt.close('all')
79
 
80
- # Handle Plotly figures
81
- if 'fig' in self.globals_dict and 'plotly' in str(type(self.globals_dict['fig'])):
82
  fig = self.globals_dict['fig']
83
- fig_path = os.path.join(self.figures_dir, f"plotly_figure_{len(figures)}.html")
84
- fig.write_html(fig_path)
85
- # Also save as image for gallery display
86
- img_path = os.path.join(self.figures_dir, f"plotly_figure_{len(figures)}.png")
87
- fig.write_image(img_path)
88
- figures.append(img_path)
89
-
90
- return True, output, figures
91
-
 
 
92
  except Exception as e:
93
  return False, str(e), []
94
  finally:
95
- sys.stdout = old_stdout
96
-
97
- def extract_and_execute_code(self, text):
98
- """Extract code blocks from markdown and execute them"""
99
- code_blocks = re.findall(r'```python(.*?)```', text, re.DOTALL)
100
- if not code_blocks:
101
- return text, None, []
102
-
103
- all_outputs = []
104
- all_figures = []
105
-
106
- for code in code_blocks:
107
- success, output, figures = self.execute_code(code.strip())
108
- if success:
109
- all_outputs.append(output)
110
- all_figures.extend(figures)
111
- else:
112
- all_outputs.append(f"Error: {output}")
113
-
114
- # Replace code blocks with code + output
115
- modified_text = text
116
- for i, (code, output) in enumerate(zip(code_blocks, all_outputs)):
117
- code_section = f"```python{code}```"
118
- output_section = f"\nOutput:\n```\n{output}\n```"
119
- modified_text = modified_text.replace(code_section, code_section + output_section)
120
-
121
- return modified_text, "\n".join(all_outputs), all_figures
122
-
123
- def query_deepseek(prompt: str, api_key: str, system_prompt: str = None):
124
- """Send a prompt to DeepSeek API"""
125
- headers = {
126
- "Content-Type": "application/json",
127
- "Authorization": f"Bearer {api_key}"
128
- }
129
-
130
- messages = []
131
- if system_prompt:
132
- messages.append({"role": "system", "content": system_prompt})
133
- messages.append({"role": "user", "content": prompt})
134
-
135
- payload = {
136
- "model": "deepseek-reasoner",
137
- "messages": messages,
138
- "stream": False
139
- }
140
 
141
- try:
142
- response = requests.post("https://api.deepseek.com/chat/completions",
143
- headers=headers,
144
- json=payload)
145
- response.raise_for_status()
146
- return response.json()["choices"][0]["message"]["content"]
147
- except Exception as e:
148
- return f"API Error: {str(e)}"
149
-
150
- class ChatAndCodeInterface:
151
- def __init__(self):
152
- self.env = CodeExecutionEnvironment()
153
- self.current_df = None
154
 
155
- def process_message(self, message, history, csv_file, api_key, system_prompt):
156
- """Process a chat message with code execution capabilities"""
157
  if not api_key:
158
- return history + [[message, "Please provide your DeepSeek API key first."]], None
159
 
160
- # Update dataframe if new CSV uploaded
161
- if csv_file and (self.current_df is None or csv_file.name != getattr(self.current_df, '_filename', None)):
162
- self.current_df = pd.read_csv(csv_file.name)
163
- self.current_df._filename = csv_file.name
164
- self.env.globals_dict['df'] = self.current_df
165
 
166
- # Build context
167
- context = ""
168
- if self.current_df is not None:
169
- context = (f"\nContext: Working with CSV file containing columns: {self.current_df.columns.tolist()}\n"
170
- f"First few rows: {self.current_df.head(3).to_dict()}\n"
171
- f"The dataframe is available as 'df' in the code environment.\n")
172
 
173
- # Get AI response
174
- full_prompt = (
175
- context +
176
- "The user might ask you to analyze data or generate visualizations. "
177
- "When you write code, wrap it in ```python``` blocks. "
178
- "You can use any Python library - they will be automatically installed. "
179
- "For interactive maps, use plotly.express and ensure you install required dependencies. "
180
- "\nUser message: " + message
181
- )
182
-
183
- response = query_deepseek(full_prompt, api_key, system_prompt)
184
 
185
- # Execute any code in the response
186
- modified_response, outputs, figures = self.env.extract_and_execute_code(response)
 
 
 
187
 
188
- # Update chat history
189
- history = history + [[message, modified_response]]
 
 
 
 
 
190
 
191
- return history, figures
192
-
193
- def create_interface():
194
- """Create the unified chat and code execution interface"""
195
- interface = ChatAndCodeInterface()
196
 
 
197
  with gr.Blocks() as demo:
198
- gr.Markdown("# AI Data Analysis Assistant with Code Execution")
199
 
200
  with gr.Row():
201
  with gr.Column(scale=1):
202
  api_key = gr.Textbox(
203
- label="DeepSeek API Key",
204
- type="password",
205
- placeholder="Enter your API key"
206
- )
207
- system_prompt = gr.Textbox(
208
- label="System Prompt",
209
- value="You are an AI assistant specialized in data analysis and Python programming. When asked to analyze data or create visualizations, you provide executable Python code.",
210
- lines=3
211
  )
212
  csv_file = gr.File(
213
- label="Upload CSV File",
214
  file_types=[".csv"]
215
  )
216
-
217
  with gr.Column(scale=3):
218
  chatbot = gr.Chatbot(height=500)
219
- gallery = gr.Gallery(
220
- label="Generated Visualizations",
221
- columns=2,
222
- height=400,
223
- object_fit="contain"
224
- )
225
 
226
  with gr.Row():
227
  msg = gr.Textbox(
228
- label="Your Message",
229
- placeholder="Ask me to analyze your data or create visualizations...",
230
- scale=9
231
  )
232
- clear = gr.Button("Clear", scale=1)
233
-
234
  msg.submit(
235
- interface.process_message,
236
- [msg, chatbot, csv_file, api_key, system_prompt],
237
  [chatbot, gallery]
238
  )
239
-
240
- clear.click(lambda: ([], []), None, [chatbot, gallery], queue=False)
241
-
242
- gr.Markdown("""
243
- ## How to Use
244
- 1. Enter your DeepSeek API key
245
- 2. Upload a CSV file for analysis
246
- 3. Chat naturally about your data analysis needs
247
-
248
- Example prompts:
249
- - "Create a histogram of the numerical columns"
250
- - "Generate an interactive map of the locations"
251
- - "Show the correlation between variables"
252
- - "Create a summary dashboard of key metrics"
253
-
254
- The assistant will:
255
- - Generate and execute Python code automatically
256
- - Handle both static and interactive visualizations
257
- - Show code, output, and visualizations in one place
258
- """)
259
 
260
  return demo
261
 
 
8
  import sys
9
  from io import StringIO
10
  import matplotlib.pyplot as plt
 
11
  from pathlib import Path
12
  import importlib
13
+ import ast
14
 
15
+ class AICodeEnvironment:
16
+ """Environment for AI to execute code safely"""
17
  def __init__(self):
18
  self.globals_dict = {}
19
+ self.temp_dir = "temp_outputs"
20
+ os.makedirs(self.temp_dir, exist_ok=True)
21
+ self.setup_base_environment()
22
 
23
+ def setup_base_environment(self):
24
+ """Set up the base environment with commonly used packages"""
25
+ self.globals_dict.update({
26
+ 'pd': pd,
27
+ 'plt': plt,
28
+ '__builtins__': __builtins__,
29
+ 'print': print
30
+ })
31
+
32
+ def dynamic_import(self, package_name):
33
+ """Dynamically import packages as needed by AI"""
34
  try:
35
+ # Install package if not present
36
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet", package_name])
37
+
38
+ # Import the package
39
+ module = importlib.import_module(package_name)
40
+ self.globals_dict[package_name] = module
41
  return True
42
+ except Exception as e:
43
+ print(f"Failed to import {package_name}: {str(e)}")
44
  return False
45
 
46
+ def handle_imports(self, code):
47
+ """Extract and handle all imports in the code"""
 
 
 
 
 
48
  try:
49
+ tree = ast.parse(code)
50
+ for node in ast.walk(tree):
51
+ if isinstance(node, (ast.Import, ast.ImportFrom)):
52
+ for name in node.names:
53
+ package = name.name.split('.')[0]
54
+ if package not in self.globals_dict:
55
+ self.dynamic_import(package)
56
+ return True
57
+ except Exception as e:
58
+ return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
+ def execute_code(self, code):
61
+ """Execute code and capture all outputs"""
62
+ # Create temporary stdout to capture prints
63
+ output_buffer = StringIO()
64
+ sys.stdout = output_buffer
65
+
66
+ try:
67
+ # Handle imports first
68
+ self.handle_imports(code)
69
+
70
  # Execute the code
71
  exec(code, self.globals_dict)
 
 
 
 
72
 
73
+ # Capture terminal output
74
+ text_output = output_buffer.getvalue()
75
+
76
+ # Handle figures
77
+ figures = []
78
+ if 'plt' in self.globals_dict and plt.get_figs():
79
  for i, fig in enumerate(plt.get_figs()):
80
+ fig_path = os.path.join(self.temp_dir, f"figure_{len(figures)}.png")
81
  fig.savefig(fig_path)
82
  figures.append(fig_path)
83
  plt.close('all')
84
 
85
+ # Check for other visualization libraries
86
+ if 'fig' in self.globals_dict:
87
  fig = self.globals_dict['fig']
88
+ # Handle Plotly figures
89
+ if 'plotly.graph_objs' in str(type(fig)):
90
+ fig_path = os.path.join(self.temp_dir, f"figure_{len(figures)}.html")
91
+ fig.write_html(fig_path)
92
+ # Also save static image
93
+ img_path = os.path.join(self.temp_dir, f"figure_{len(figures)}.png")
94
+ fig.write_image(img_path)
95
+ figures.append(img_path)
96
+
97
+ return True, text_output, figures
98
+
99
  except Exception as e:
100
  return False, str(e), []
101
  finally:
102
+ sys.stdout = sys.__stdout__
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
+ def create_interface():
105
+ """Create the interface for AI code execution"""
106
+ env = AICodeEnvironment()
 
 
 
 
 
 
 
 
 
 
107
 
108
+ def process_message(message, history, csv_file, api_key):
109
+ """Process message and execute any code blocks"""
110
  if not api_key:
111
+ return history + [[message, "Please provide your API key."]], None
112
 
113
+ # Update environment with dataframe if CSV uploaded
114
+ if csv_file:
115
+ env.globals_dict['df'] = pd.read_csv(csv_file.name)
 
 
116
 
117
+ # Get response from AI (example structure)
118
+ response = query_ai(message, api_key)
 
 
 
 
119
 
120
+ # Extract and execute code blocks
121
+ code_blocks = response.split("```python")
122
+ outputs = []
123
+ figures = []
 
 
 
 
 
 
 
124
 
125
+ for block in code_blocks[1:]: # Skip first split as it's before any code block
126
+ code = block.split("```")[0].strip()
127
+ success, output, new_figures = env.execute_code(code)
128
+ outputs.append(output)
129
+ figures.extend(new_figures)
130
 
131
+ # Format response with outputs
132
+ modified_response = response
133
+ for i, output in enumerate(outputs):
134
+ modified_response = modified_response.replace(
135
+ f"```python{code_blocks[i+1].split('```')[0]}```",
136
+ f"```python{code_blocks[i+1].split('```')[0]}```\nOutput:\n{output}"
137
+ )
138
 
139
+ return history + [[message, modified_response]], figures
 
 
 
 
140
 
141
+ # Create Gradio interface
142
  with gr.Blocks() as demo:
143
+ gr.Markdown("# AI Code Execution Environment")
144
 
145
  with gr.Row():
146
  with gr.Column(scale=1):
147
  api_key = gr.Textbox(
148
+ label="API Key",
149
+ type="password"
 
 
 
 
 
 
150
  )
151
  csv_file = gr.File(
152
+ label="Upload CSV",
153
  file_types=[".csv"]
154
  )
155
+
156
  with gr.Column(scale=3):
157
  chatbot = gr.Chatbot(height=500)
158
+ gallery = gr.Gallery(label="Outputs")
 
 
 
 
 
159
 
160
  with gr.Row():
161
  msg = gr.Textbox(
162
+ label="Message",
163
+ placeholder="Ask me to analyze your data..."
 
164
  )
165
+ clear = gr.Button("Clear")
166
+
167
  msg.submit(
168
+ process_message,
169
+ [msg, chatbot, csv_file, api_key],
170
  [chatbot, gallery]
171
  )
172
+ clear.click(lambda: ([], []), None, [chatbot, gallery])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
  return demo
175