jzou19950715 commited on
Commit
069bc6a
·
verified ·
1 Parent(s): dc45114

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +172 -168
app.py CHANGED
@@ -8,57 +8,106 @@ import tempfile
8
  import sys
9
  from io import StringIO
10
  import matplotlib.pyplot as plt
11
- import base64
12
  from pathlib import Path
 
13
 
14
- def install_package(package_name):
15
- """Dynamically install any Python package"""
16
- try:
17
- subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])
18
- return True
19
- except:
20
- return False
21
-
22
- def safe_execute_code(code: str, globals_dict=None):
23
- """Execute code safely and capture all outputs"""
24
- if globals_dict is None:
25
- globals_dict = {}
26
-
27
- # Redirect stdout to capture print outputs
28
- old_stdout = sys.stdout
29
- redirected_output = StringIO()
30
- sys.stdout = redirected_output
31
-
32
- try:
33
- # First pass: collect and install required imports
34
- import_lines = [line for line in code.split('\n') if 'import' in line]
35
- for line in import_lines:
36
- parts = line.split()
37
- if parts[0] == 'import':
38
- package = parts[1].split('.')[0]
39
- install_package(package)
40
- elif parts[0] == 'from':
41
- package = parts[1].split('.')[0]
42
- install_package(package)
43
-
44
- # Execute the code
45
- exec(code, globals_dict)
46
- output = redirected_output.getvalue()
47
 
48
- # Handle any matplotlib figures
49
- figures = []
50
- if plt.get_figs():
51
- for i, fig in enumerate(plt.get_figs()):
52
- with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
53
- fig.savefig(tmp.name)
54
- figures.append(tmp.name)
55
- plt.close('all')
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- return True, output, figures
58
- except Exception as e:
59
- return False, f"Error executing code:\n{str(e)}", []
60
- finally:
61
- sys.stdout = old_stdout
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  def query_deepseek(prompt: str, api_key: str, system_prompt: str = None):
64
  """Send a prompt to DeepSeek API"""
@@ -87,74 +136,56 @@ def query_deepseek(prompt: str, api_key: str, system_prompt: str = None):
87
  except Exception as e:
88
  return f"API Error: {str(e)}"
89
 
90
- def chat_function(message, history, csv_file, api_key, system_prompt):
91
- """Handle chat interactions"""
92
- if not api_key:
93
- return "Please provide your DeepSeek API key first."
94
-
95
- context = ""
96
- if csv_file:
97
- df = pd.read_csv(csv_file.name)
98
- context = f"\nContext: I have loaded a CSV file with columns: {df.columns.tolist()}\n"
99
- context += f"First few rows: {df.head(3).to_dict()}\n"
100
-
101
- full_prompt = context + message
102
- response = query_deepseek(full_prompt, api_key, system_prompt)
103
- return response
104
-
105
- def analyze_data(csv_file, api_key, system_prompt, code_request):
106
- """Generate and execute code for data analysis"""
107
- if not csv_file:
108
- return "Please upload a CSV file first.", None, None, []
109
 
110
- if not api_key:
111
- return "Please provide your DeepSeek API key.", None, None, []
112
-
113
- try:
114
- # Read the CSV file
115
- df = pd.read_csv(csv_file.name)
116
 
117
- # Build the prompt
118
- prompt = f"""I have a CSV file with columns: {df.columns.tolist()}.
119
- First few rows: {df.head(3).to_dict()}.
120
-
121
- User request: {code_request}
122
-
123
- Please generate Python code that:
124
- 1. Analyzes the data according to the request
125
- 2. Creates relevant visualizations
126
- 3. Handles potential errors and edge cases
127
- 4. Includes helpful comments"""
128
-
129
- # Get code from API
130
- generated_code = query_deepseek(prompt, api_key, system_prompt)
131
 
132
- # Set up execution environment
133
- globals_dict = {
134
- 'pd': pd,
135
- 'plt': plt,
136
- 'df': df,
137
- 'np': __import__('numpy')
138
- }
139
-
140
- # Execute the code
141
- success, execution_output, figures = safe_execute_code(generated_code, globals_dict)
142
 
143
- if not success:
144
- return f"Execution failed: {execution_output}", generated_code, None, []
145
-
146
- return "Analysis completed successfully.", generated_code, execution_output, figures
147
-
148
- except Exception as e:
149
- return f"Error during analysis: {str(e)}", None, None, []
 
 
 
 
 
 
 
 
 
 
 
150
 
151
  def create_interface():
152
- """Create the dual-channel Gradio interface"""
153
- with gr.Blocks() as interface:
154
- gr.Markdown("# AI Data Analysis Assistant")
 
 
155
 
156
  with gr.Row():
157
- # Sidebar with common inputs
158
  with gr.Column(scale=1):
159
  api_key = gr.Textbox(
160
  label="DeepSeek API Key",
@@ -163,7 +194,7 @@ def create_interface():
163
  )
164
  system_prompt = gr.Textbox(
165
  label="System Prompt",
166
- value="You are an AI assistant specialized in data analysis and Python programming.",
167
  lines=3
168
  )
169
  csv_file = gr.File(
@@ -171,72 +202,45 @@ def create_interface():
171
  file_types=[".csv"]
172
  )
173
 
174
- # Main content area with tabs
175
  with gr.Column(scale=3):
176
- with gr.Tabs():
177
- # Chat Interface Tab
178
- with gr.TabItem("Chat"):
179
- chatbot = gr.Chatbot()
180
- msg = gr.Textbox(label="Your Message")
181
- clear = gr.Button("Clear Chat")
182
-
183
- msg.submit(
184
- chat_function,
185
- [msg, chatbot, csv_file, api_key, system_prompt],
186
- chatbot
187
- )
188
- clear.click(lambda: None, None, chatbot, queue=False)
189
-
190
- # Code Generation Tab
191
- with gr.TabItem("Code Generation"):
192
- code_request = gr.Textbox(
193
- label="What analysis would you like to perform?",
194
- placeholder="e.g., Create a correlation matrix and visualize key relationships",
195
- lines=3
196
- )
197
- analyze_button = gr.Button("Generate & Execute Code")
198
-
199
- with gr.Row():
200
- with gr.Column():
201
- status_output = gr.Textbox(label="Status")
202
- code_output = gr.Code(
203
- label="Generated Code",
204
- language="python"
205
- )
206
- execution_output = gr.Textbox(
207
- label="Execution Output",
208
- lines=10
209
- )
210
- with gr.Column():
211
- gallery = gr.Gallery(
212
- label="Visualizations",
213
- columns=2,
214
- rows=2,
215
- height="auto"
216
- )
217
-
218
- analyze_button.click(
219
- analyze_data,
220
- inputs=[csv_file, api_key, system_prompt, code_request],
221
- outputs=[status_output, code_output, execution_output, gallery]
222
- )
223
-
224
  gr.Markdown("""
225
  ## How to Use
226
  1. Enter your DeepSeek API key
227
  2. Upload a CSV file for analysis
228
- 3. Use either:
229
- - Chat tab: Have a conversation about your data
230
- - Code Generation tab: Get executable Python code for specific analyses
231
 
232
- The tool will:
233
- - Generate and execute Python code
234
- - Create visualizations
235
- - Allow interactive exploration of your data
 
 
 
 
 
236
  """)
237
-
238
- return interface
239
 
240
  if __name__ == "__main__":
241
- interface = create_interface()
242
- interface.launch()
 
8
  import sys
9
  from io import StringIO
10
  import matplotlib.pyplot as plt
11
+ import re
12
  from pathlib import Path
13
+ import importlib
14
 
15
+ class CodeExecutionEnvironment:
16
+ def __init__(self):
17
+ self.globals_dict = {}
18
+ self.figures_dir = "temp_figures"
19
+ os.makedirs(self.figures_dir, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ def install_package(self, package_name):
22
+ """Dynamically install a Python package"""
23
+ try:
24
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package_name],
25
+ stdout=subprocess.DEVNULL,
26
+ stderr=subprocess.DEVNULL)
27
+ return True
28
+ except:
29
+ return False
30
+
31
+ def extract_and_execute_code(self, text):
32
+ """Extract code blocks from markdown and execute them"""
33
+ # Pattern for code blocks
34
+ code_blocks = re.findall(r'```python(.*?)```', text, re.DOTALL)
35
+ if not code_blocks:
36
+ return text, None, []
37
+
38
+ all_outputs = []
39
+ all_figures = []
40
 
41
+ for code in code_blocks:
42
+ success, output, figures = self.execute_code(code.strip())
43
+ if success:
44
+ all_outputs.append(output)
45
+ all_figures.extend(figures)
46
+ else:
47
+ all_outputs.append(f"Error: {output}")
48
+
49
+ # Replace code blocks with code + output
50
+ modified_text = text
51
+ for i, (code, output) in enumerate(zip(code_blocks, all_outputs)):
52
+ code_section = f"```python{code}```"
53
+ output_section = f"\nOutput:\n```\n{output}\n```"
54
+ modified_text = modified_text.replace(code_section, code_section + output_section)
55
+
56
+ return modified_text, "\n".join(all_outputs), all_figures
57
+
58
+ def execute_code(self, code):
59
+ """Execute code in the managed environment"""
60
+ # Redirect stdout to capture prints
61
+ old_stdout = sys.stdout
62
+ redirected_output = StringIO()
63
+ sys.stdout = redirected_output
64
+
65
+ try:
66
+ # First pass: collect and install required imports
67
+ import_lines = [line for line in code.split('\n') if 'import' in line]
68
+ for line in import_lines:
69
+ parts = line.split()
70
+ if parts[0] == 'import':
71
+ package = parts[1].split('.')[0]
72
+ if package not in sys.modules:
73
+ self.install_package(package)
74
+ try:
75
+ self.globals_dict[package] = importlib.import_module(package)
76
+ except:
77
+ pass
78
+ elif parts[0] == 'from':
79
+ package = parts[1].split('.')[0]
80
+ if package not in sys.modules:
81
+ self.install_package(package)
82
+
83
+ # Add common data science packages to globals
84
+ if 'pd' not in self.globals_dict:
85
+ self.globals_dict['pd'] = pd
86
+ if 'plt' not in self.globals_dict:
87
+ self.globals_dict['plt'] = plt
88
+ if 'np' not in self.globals_dict:
89
+ import numpy as np
90
+ self.globals_dict['np'] = np
91
+
92
+ # Execute the code
93
+ exec(code, self.globals_dict)
94
+ output = redirected_output.getvalue()
95
+
96
+ # Capture figures
97
+ figures = []
98
+ if plt.get_figs():
99
+ for i, fig in enumerate(plt.get_figs()):
100
+ fig_path = os.path.join(self.figures_dir, f"figure_{len(figures)}.png")
101
+ fig.savefig(fig_path)
102
+ figures.append(fig_path)
103
+ plt.close('all')
104
+
105
+ return True, output, figures
106
+
107
+ except Exception as e:
108
+ return False, str(e), []
109
+ finally:
110
+ sys.stdout = old_stdout
111
 
112
  def query_deepseek(prompt: str, api_key: str, system_prompt: str = None):
113
  """Send a prompt to DeepSeek API"""
 
136
  except Exception as e:
137
  return f"API Error: {str(e)}"
138
 
139
+ class ChatAndCodeInterface:
140
+ def __init__(self):
141
+ self.env = CodeExecutionEnvironment()
142
+ self.current_df = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
+ def process_message(self, message, history, csv_file, api_key, system_prompt):
145
+ """Process a chat message with code execution capabilities"""
146
+ if not api_key:
147
+ return history + [[message, "Please provide your DeepSeek API key first."]], None
 
 
148
 
149
+ # Update dataframe if new CSV uploaded
150
+ if csv_file and (self.current_df is None or csv_file.name != getattr(self.current_df, '_filename', None)):
151
+ self.current_df = pd.read_csv(csv_file.name)
152
+ self.current_df._filename = csv_file.name
153
+ self.env.globals_dict['df'] = self.current_df
 
 
 
 
 
 
 
 
 
154
 
155
+ # Build context
156
+ context = ""
157
+ if self.current_df is not None:
158
+ context = (f"\nContext: Working with CSV file containing columns: {self.current_df.columns.tolist()}\n"
159
+ f"First few rows: {self.current_df.head(3).to_dict()}\n"
160
+ f"The dataframe is available as 'df' in the code environment.\n")
 
 
 
 
161
 
162
+ # Get AI response
163
+ full_prompt = (
164
+ context +
165
+ "The user might ask you to analyze data or generate visualizations. "
166
+ "When you write code, wrap it in ```python``` blocks. "
167
+ "You can use any Python library - they will be automatically installed. "
168
+ "\nUser message: " + message
169
+ )
170
+
171
+ response = query_deepseek(full_prompt, api_key, system_prompt)
172
+
173
+ # Execute any code in the response
174
+ modified_response, outputs, figures = self.env.extract_and_execute_code(response)
175
+
176
+ # Update chat history
177
+ history = history + [[message, modified_response]]
178
+
179
+ return history, figures
180
 
181
  def create_interface():
182
+ """Create the unified chat and code execution interface"""
183
+ interface = ChatAndCodeInterface()
184
+
185
+ with gr.Blocks() as demo:
186
+ gr.Markdown("# AI Data Analysis Assistant with Code Execution")
187
 
188
  with gr.Row():
 
189
  with gr.Column(scale=1):
190
  api_key = gr.Textbox(
191
  label="DeepSeek API Key",
 
194
  )
195
  system_prompt = gr.Textbox(
196
  label="System Prompt",
197
+ value="You are an AI assistant specialized in data analysis and Python programming. When asked to analyze data or create visualizations, you provide executable Python code.",
198
  lines=3
199
  )
200
  csv_file = gr.File(
 
202
  file_types=[".csv"]
203
  )
204
 
 
205
  with gr.Column(scale=3):
206
+ chatbot = gr.Chatbot(height=400)
207
+ gallery = gr.Gallery(label="Generated Visualizations", columns=2, height=300)
208
+
209
+ with gr.Row():
210
+ msg = gr.Textbox(
211
+ label="Your Message",
212
+ placeholder="Ask me to analyze your data or create visualizations...",
213
+ scale=9
214
+ )
215
+ clear = gr.Button("Clear", scale=1)
216
+
217
+ msg.submit(
218
+ interface.process_message,
219
+ [msg, chatbot, csv_file, api_key, system_prompt],
220
+ [chatbot, gallery]
221
+ )
222
+
223
+ clear.click(lambda: ([], []), None, [chatbot, gallery], queue=False)
224
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  gr.Markdown("""
226
  ## How to Use
227
  1. Enter your DeepSeek API key
228
  2. Upload a CSV file for analysis
229
+ 3. Chat naturally about your data analysis needs
 
 
230
 
231
+ Example prompts:
232
+ - "Create a histogram of the numerical columns"
233
+ - "Analyze the correlation between variables"
234
+ - "Generate summary statistics and visualize key trends"
235
+
236
+ The assistant will:
237
+ - Generate and execute Python code automatically
238
+ - Show both code and its output in the chat
239
+ - Display generated visualizations in the gallery
240
  """)
241
+
242
+ return demo
243
 
244
  if __name__ == "__main__":
245
+ demo = create_interface()
246
+ demo.launch()