jzou19950715 committed on
Commit
fc29f53
·
verified ·
1 Parent(s): a1bec31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -93
app.py CHANGED
@@ -6,11 +6,81 @@ import subprocess
6
  import gradio as gr
7
  import tempfile
8
  import sys
9
- import matplotlib.pyplot as plt
10
  from io import StringIO
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- def query_api(prompt, api_url, api_key, system_prompt):
13
- """Send a prompt to the specified API and return the response."""
14
  headers = {
15
  "Content-Type": "application/json",
16
  "Authorization": f"Bearer {api_key}"
@@ -20,8 +90,7 @@ def query_api(prompt, api_url, api_key, system_prompt):
20
  "messages": [
21
  {"role": "system", "content": system_prompt},
22
  {"role": "user", "content": prompt}
23
- ],
24
- "stream": False
25
  }
26
 
27
  try:
@@ -31,94 +100,55 @@ def query_api(prompt, api_url, api_key, system_prompt):
31
  except requests.exceptions.RequestException as e:
32
  return f"API Error: {str(e)}"
33
 
34
- def install_package(package):
35
- """Install a Python package using pip."""
36
- try:
37
- subprocess.check_call([sys.executable, "-m", "pip", "install", package])
38
- return True
39
- except subprocess.CalledProcessError:
40
- return False
41
-
42
- def safe_execute_code(code, globals_dict=None):
43
- """Safely execute the generated Python code in a restricted environment."""
44
- if globals_dict is None:
45
- globals_dict = {}
46
 
47
- # Redirect stdout to capture print outputs
48
- old_stdout = sys.stdout
49
- redirected_output = StringIO()
50
- sys.stdout = redirected_output
51
-
52
- try:
53
- # Execute the code in the restricted environment
54
- exec(code, globals_dict)
55
- output = redirected_output.getvalue()
56
- return True, output
57
- except Exception as e:
58
- return False, f"Error executing code: {str(e)}"
59
- finally:
60
- sys.stdout = old_stdout
61
-
62
- def analyze_data(csv_file, api_url, api_key, system_prompt):
63
- """Analyze the uploaded CSV file using the specified API."""
64
  if not csv_file:
65
- return "No file uploaded.", None, None
66
-
67
  try:
 
 
 
68
  # Read the CSV file
69
  df = pd.read_csv(csv_file.name)
70
  columns = df.columns.tolist()
71
  sample_data = df.head(3).to_dict()
72
 
73
  # Build the prompt
74
- prompt = (
75
- f"I have a CSV file with columns: {columns}. "
76
- f"The first few rows are: {sample_data}. "
77
- "Please generate Python code to analyze this data. Include:"
78
- "1. Basic statistical analysis"
79
- "2. Data visualization using matplotlib or seaborn"
80
- "3. Any interesting patterns or insights"
81
- "Make sure to use only standard data science libraries."
82
- )
83
-
84
- # Get code from API
85
- generated_code = query_api(prompt, api_url, api_key, system_prompt)
86
-
87
- # Create a temporary directory for generated files
88
- with tempfile.TemporaryDirectory() as temp_dir:
89
- os.chdir(temp_dir)
90
-
91
- # Save the DataFrame in the temporary directory
92
- df.to_csv("input_data.csv", index=False)
93
-
94
- # Prepare the execution environment
95
- globals_dict = {
96
- 'pd': pd,
97
- 'plt': plt,
98
- 'df': df,
99
- '__file__': 'input_data.csv'
100
- }
101
-
102
- # Execute the code
103
- success, execution_output = safe_execute_code(generated_code, globals_dict)
104
 
105
- if not success:
106
- return "Code execution failed.", generated_code, execution_output
 
 
 
107
 
108
- # Save any generated plots
109
- if plt.get_figs():
110
- plt.savefig("visualization.png")
111
- plt.close('all')
112
- if os.path.exists("visualization.png"):
113
- return "Analysis completed successfully.", generated_code, (execution_output, "visualization.png")
114
 
115
- return "Analysis completed successfully.", generated_code, (execution_output, None)
 
 
 
 
 
 
 
 
 
 
116
 
117
  except Exception as e:
118
- return f"Error during analysis: {str(e)}", None, None
119
 
120
- # Create Gradio interface
121
  def create_interface():
 
122
  with gr.Blocks() as interface:
123
  gr.Markdown("# AI-Powered Data Analysis Tool")
124
 
@@ -126,18 +156,18 @@ def create_interface():
126
  with gr.Column():
127
  api_url = gr.Textbox(
128
  label="API URL",
129
- placeholder="Enter your API endpoint URL",
130
  type="text"
131
  )
132
  api_key = gr.Textbox(
133
  label="API Key",
134
- placeholder="Enter your API key",
135
  type="password"
136
  )
137
  system_prompt = gr.Textbox(
138
  label="System Prompt",
139
  placeholder="Enter system prompt for the AI",
140
- value="You are an AI assistant specialized in data analysis, visualization, and Python programming.",
141
  lines=3
142
  )
143
  csv_file = gr.File(
@@ -152,30 +182,33 @@ def create_interface():
152
  label="Generated Code",
153
  language="python"
154
  )
155
- with gr.Row():
156
- text_output = gr.Textbox(
157
- label="Analysis Output",
158
- lines=10
159
- )
160
- image_output = gr.Image(
161
- label="Visualization",
162
- type="filepath"
163
- )
164
 
165
  analyze_button.click(
166
  fn=analyze_data,
167
  inputs=[csv_file, api_url, api_key, system_prompt],
168
- outputs=[status_output, code_output, [text_output, image_output]]
169
  )
170
 
171
  gr.Markdown("""
172
  ## How to Use
173
- 1. Enter your API URL and key for the AI service you want to use (e.g., OpenAI, DeepSeek)
174
  2. Customize the system prompt if desired
175
- 3. Upload a CSV file
176
  4. Click 'Analyze Data' to generate and execute analysis code
177
 
178
- The tool will generate Python code to analyze your data and create visualizations.
 
 
 
 
179
  """)
180
 
181
  return interface
 
6
  import gradio as gr
7
  import tempfile
8
  import sys
 
9
  from io import StringIO
10
+ import matplotlib.pyplot as plt
11
+ import seaborn as sns
12
+ import numpy as np
13
+ from typing import Dict, Any, Tuple, Optional
14
+ import ast
15
+
16
# Whitelist of top-level modules that generated code is allowed to import
# (mirrors the smolagents approach). Order matters: the list is joined into
# the prompt text and echoed in validation error messages.
SAFE_IMPORTS = [
    "pandas",
    "numpy",
    "matplotlib",
    "seaborn",
    "sklearn",
    "scipy",
    "statsmodels",
    "plotly",
    "math",
    "datetime",
    "collections",
    "itertools",
    "functools",
    "operator",
]
22
+
23
class SafeExecutor:
    """Run generated Python code after checking its imports against a whitelist.

    NOTE(security): the import check is purely static. The code is still run
    with ``exec`` in this process with the normal builtins available, so the
    whitelist does not stop dynamic imports (e.g. ``__import__``) or other
    builtin access. Do not treat this as a sandbox for untrusted code.
    """

    def __init__(self, allowed_imports=None):
        """Store the whitelist of importable top-level modules.

        Args:
            allowed_imports: Collection of top-level module names the
                executed code may import. Falls back to ``SAFE_IMPORTS``
                when omitted or empty.
        """
        self.allowed_imports = allowed_imports or SAFE_IMPORTS

    def validate_imports(self, code: str) -> bool:
        """Check that every import statement in ``code`` targets an allowed module.

        Args:
            code: Python source text to inspect.

        Returns:
            ``True`` when all imports are allowed.

        Raises:
            ValueError: If the code cannot be parsed, uses a relative import,
                or imports a module whose top-level package is not in
                ``self.allowed_imports``. Every message starts with
                ``"Code validation error: "``.
        """
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError, RecursionError) as exc:
            raise ValueError(f"Code validation error: {exc}") from exc

        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                modules = [alias.name.split('.')[0] for alias in node.names]
            elif isinstance(node, ast.ImportFrom):
                # For ``from X import Y`` the module being imported is X
                # (node.module); node.names only holds the attributes pulled
                # out of it, so checking those would reject
                # ``from pandas import DataFrame``.
                if node.level or not node.module:
                    raise ValueError(
                        "Code validation error: Relative imports are not allowed."
                    )
                modules = [node.module.split('.')[0]]
            else:
                continue

            for module in modules:
                if module not in self.allowed_imports:
                    raise ValueError(
                        "Code validation error: "
                        f"Import of '{module}' is not allowed. "
                        f"Allowed imports: {self.allowed_imports}"
                    )
        return True

    def execute_code(
        self,
        code: str,
        globals_dict: Optional[Dict[str, Any]] = None
    ) -> Tuple[Optional[str], str]:
        """Validate and execute ``code``, capturing printed output and any figure.

        Args:
            code: Python source text to run.
            globals_dict: Namespace the code runs in. It is mutated: each
                importable module from ``self.allowed_imports`` is added
                under its own name, and the executed code may add more.

        Returns:
            ``(figure_path, output)``. ``figure_path`` is the path of a PNG
            holding the current matplotlib figure, or ``None`` when no figure
            is open. The PNG is created with ``delete=False``, so removing it
            is the caller's job. On any failure the result is
            ``(None, "Error executing code:\\n<message>")``.
        """
        if globals_dict is None:
            globals_dict = {}

        # Pre-load whitelisted modules under their full names.
        for module in self.allowed_imports:
            try:
                globals_dict[module] = __import__(module)
            except ImportError:
                # Best effort: a whitelisted package that is not installed
                # is simply left out of the namespace.
                pass

        # Redirect stdout so print() output from the code can be returned.
        old_stdout = sys.stdout
        redirected_output = StringIO()
        sys.stdout = redirected_output

        try:
            self.validate_imports(code)

            exec(code, globals_dict)
            output = redirected_output.getvalue()

            # pyplot has no get_figs(); get_fignums() lists the open figures.
            if plt.get_fignums():
                # Only reserve the file name here and close the handle before
                # matplotlib writes to it, so the file is not opened twice.
                with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
                    figure_path = tmp.name
                plt.savefig(figure_path)
                plt.close('all')
                return figure_path, output

            return None, output

        except Exception as e:
            return None, f"Error executing code:\n{str(e)}"
        finally:
            sys.stdout = old_stdout
81
 
82
+ def query_api(prompt: str, api_url: str, api_key: str, system_prompt: str) -> str:
83
+ """Send a prompt to the specified API and return the response"""
84
  headers = {
85
  "Content-Type": "application/json",
86
  "Authorization": f"Bearer {api_key}"
 
90
  "messages": [
91
  {"role": "system", "content": system_prompt},
92
  {"role": "user", "content": prompt}
93
+ ]
 
94
  }
95
 
96
  try:
 
100
  except requests.exceptions.RequestException as e:
101
  return f"API Error: {str(e)}"
102
 
103
def analyze_data(
    csv_file: Any,
    api_url: str,
    api_key: str,
    system_prompt: str
) -> Tuple[str, Optional[str], Optional[str], Optional[str]]:
    """Ask the API for analysis code for an uploaded CSV and execute it.

    Args:
        csv_file: Uploaded file object; only its ``.name`` attribute (the
            path handed to ``pd.read_csv``) is used. Falsy means no upload.
        api_url: Endpoint passed through to ``query_api``.
        api_key: Credential passed through to ``query_api``.
        system_prompt: System message passed through to ``query_api``.

    Returns:
        ``(status, generated_code, execution_output, visualization_path)``.
        When nothing was uploaded, the API call failed, or an exception was
        raised, only ``status`` is set and the other three are ``None``.
    """
    if not csv_file:
        return "No file uploaded.", None, None, None

    try:
        # Create safe executor
        executor = SafeExecutor()

        # Read the CSV file
        df = pd.read_csv(csv_file.name)
        columns = df.columns.tolist()
        sample_data = df.head(3).to_dict()

        # Build the prompt
        prompt = f"""Analyze this CSV file with columns: {columns}.
        Sample data: {sample_data}

        Generate Python code that:
        1. Creates insightful visualizations using matplotlib or seaborn
        2. Performs relevant statistical analysis
        3. Identifies key patterns or insights
        4. Properly handles potential data issues

        Important: Use only these libraries: {', '.join(SAFE_IMPORTS)}"""

        # Get code from API
        generated_code = query_api(prompt, api_url, api_key, system_prompt)

        # query_api reports request failures as an "API Error: ..." string
        # instead of raising; surface that as the status rather than trying
        # to execute the error text as Python.
        if isinstance(generated_code, str) and generated_code.startswith("API Error:"):
            return generated_code, None, None, None

        # Create execution environment
        globals_dict = {'df': df, 'pd': pd, 'np': np, 'plt': plt, 'sns': sns}

        # Execute the code
        vis_path, execution_output = executor.execute_code(generated_code, globals_dict)

        # The executor signals failure with no figure path and an
        # "Error executing code:" message; do not report that as success.
        failed = (
            vis_path is None
            and isinstance(execution_output, str)
            and execution_output.startswith("Error executing code:")
        )
        status = "Code execution failed." if failed else "Analysis completed successfully."
        return status, generated_code, execution_output, vis_path

    except Exception as e:
        return f"Error during analysis: {str(e)}", None, None, None
149
 
 
150
  def create_interface():
151
+ """Create the Gradio interface"""
152
  with gr.Blocks() as interface:
153
  gr.Markdown("# AI-Powered Data Analysis Tool")
154
 
 
156
  with gr.Column():
157
  api_url = gr.Textbox(
158
  label="API URL",
159
+ placeholder="Enter API endpoint URL",
160
  type="text"
161
  )
162
  api_key = gr.Textbox(
163
  label="API Key",
164
+ placeholder="Enter API key",
165
  type="password"
166
  )
167
  system_prompt = gr.Textbox(
168
  label="System Prompt",
169
  placeholder="Enter system prompt for the AI",
170
+ value="You are an AI assistant specialized in data analysis and visualization.",
171
  lines=3
172
  )
173
  csv_file = gr.File(
 
182
  label="Generated Code",
183
  language="python"
184
  )
185
+ execution_output = gr.Textbox(
186
+ label="Execution Output",
187
+ lines=10
188
+ )
189
+ visualization_output = gr.Image(
190
+ label="Visualization",
191
+ type="filepath"
192
+ )
 
193
 
194
  analyze_button.click(
195
  fn=analyze_data,
196
  inputs=[csv_file, api_url, api_key, system_prompt],
197
+ outputs=[status_output, code_output, execution_output, visualization_output]
198
  )
199
 
200
  gr.Markdown("""
201
  ## How to Use
202
+ 1. Enter your API URL and key (supports various API providers)
203
  2. Customize the system prompt if desired
204
+ 3. Upload a CSV file for analysis
205
  4. Click 'Analyze Data' to generate and execute analysis code
206
 
207
+ The tool will:
208
+ - Generate Python code to analyze your data
209
+ - Execute the code safely in a controlled environment
210
+ - Display both textual results and visualizations
211
+ - Support common data science libraries
212
  """)
213
 
214
  return interface