Update app.py
Browse files
app.py
CHANGED
@@ -6,11 +6,81 @@ import subprocess
|
|
6 |
import gradio as gr
|
7 |
import tempfile
|
8 |
import sys
|
9 |
-
import matplotlib.pyplot as plt
|
10 |
from io import StringIO
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
-
def query_api(prompt, api_url, api_key, system_prompt):
|
13 |
-
"""Send a prompt to the specified API and return the response
|
14 |
headers = {
|
15 |
"Content-Type": "application/json",
|
16 |
"Authorization": f"Bearer {api_key}"
|
@@ -20,8 +90,7 @@ def query_api(prompt, api_url, api_key, system_prompt):
|
|
20 |
"messages": [
|
21 |
{"role": "system", "content": system_prompt},
|
22 |
{"role": "user", "content": prompt}
|
23 |
-
]
|
24 |
-
"stream": False
|
25 |
}
|
26 |
|
27 |
try:
|
@@ -31,94 +100,55 @@ def query_api(prompt, api_url, api_key, system_prompt):
|
|
31 |
except requests.exceptions.RequestException as e:
|
32 |
return f"API Error: {str(e)}"
|
33 |
|
34 |
-
def
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
def safe_execute_code(code, globals_dict=None):
|
43 |
-
"""Safely execute the generated Python code in a restricted environment."""
|
44 |
-
if globals_dict is None:
|
45 |
-
globals_dict = {}
|
46 |
|
47 |
-
# Redirect stdout to capture print outputs
|
48 |
-
old_stdout = sys.stdout
|
49 |
-
redirected_output = StringIO()
|
50 |
-
sys.stdout = redirected_output
|
51 |
-
|
52 |
-
try:
|
53 |
-
# Execute the code in the restricted environment
|
54 |
-
exec(code, globals_dict)
|
55 |
-
output = redirected_output.getvalue()
|
56 |
-
return True, output
|
57 |
-
except Exception as e:
|
58 |
-
return False, f"Error executing code: {str(e)}"
|
59 |
-
finally:
|
60 |
-
sys.stdout = old_stdout
|
61 |
-
|
62 |
-
def analyze_data(csv_file, api_url, api_key, system_prompt):
|
63 |
-
"""Analyze the uploaded CSV file using the specified API."""
|
64 |
if not csv_file:
|
65 |
-
return "No file uploaded.", None, None
|
66 |
-
|
67 |
try:
|
|
|
|
|
|
|
68 |
# Read the CSV file
|
69 |
df = pd.read_csv(csv_file.name)
|
70 |
columns = df.columns.tolist()
|
71 |
sample_data = df.head(3).to_dict()
|
72 |
|
73 |
# Build the prompt
|
74 |
-
prompt =
|
75 |
-
|
76 |
-
f"The first few rows are: {sample_data}. "
|
77 |
-
"Please generate Python code to analyze this data. Include:"
|
78 |
-
"1. Basic statistical analysis"
|
79 |
-
"2. Data visualization using matplotlib or seaborn"
|
80 |
-
"3. Any interesting patterns or insights"
|
81 |
-
"Make sure to use only standard data science libraries."
|
82 |
-
)
|
83 |
-
|
84 |
-
# Get code from API
|
85 |
-
generated_code = query_api(prompt, api_url, api_key, system_prompt)
|
86 |
-
|
87 |
-
# Create a temporary directory for generated files
|
88 |
-
with tempfile.TemporaryDirectory() as temp_dir:
|
89 |
-
os.chdir(temp_dir)
|
90 |
-
|
91 |
-
# Save the DataFrame in the temporary directory
|
92 |
-
df.to_csv("input_data.csv", index=False)
|
93 |
-
|
94 |
-
# Prepare the execution environment
|
95 |
-
globals_dict = {
|
96 |
-
'pd': pd,
|
97 |
-
'plt': plt,
|
98 |
-
'df': df,
|
99 |
-
'__file__': 'input_data.csv'
|
100 |
-
}
|
101 |
-
|
102 |
-
# Execute the code
|
103 |
-
success, execution_output = safe_execute_code(generated_code, globals_dict)
|
104 |
|
105 |
-
|
106 |
-
|
|
|
|
|
|
|
107 |
|
108 |
-
|
109 |
-
if plt.get_figs():
|
110 |
-
plt.savefig("visualization.png")
|
111 |
-
plt.close('all')
|
112 |
-
if os.path.exists("visualization.png"):
|
113 |
-
return "Analysis completed successfully.", generated_code, (execution_output, "visualization.png")
|
114 |
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
except Exception as e:
|
118 |
-
return f"Error during analysis: {str(e)}", None, None
|
119 |
|
120 |
-
# Create Gradio interface
|
121 |
def create_interface():
|
|
|
122 |
with gr.Blocks() as interface:
|
123 |
gr.Markdown("# AI-Powered Data Analysis Tool")
|
124 |
|
@@ -126,18 +156,18 @@ def create_interface():
|
|
126 |
with gr.Column():
|
127 |
api_url = gr.Textbox(
|
128 |
label="API URL",
|
129 |
-
placeholder="Enter
|
130 |
type="text"
|
131 |
)
|
132 |
api_key = gr.Textbox(
|
133 |
label="API Key",
|
134 |
-
placeholder="Enter
|
135 |
type="password"
|
136 |
)
|
137 |
system_prompt = gr.Textbox(
|
138 |
label="System Prompt",
|
139 |
placeholder="Enter system prompt for the AI",
|
140 |
-
value="You are an AI assistant specialized in data analysis
|
141 |
lines=3
|
142 |
)
|
143 |
csv_file = gr.File(
|
@@ -152,30 +182,33 @@ def create_interface():
|
|
152 |
label="Generated Code",
|
153 |
language="python"
|
154 |
)
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
)
|
164 |
|
165 |
analyze_button.click(
|
166 |
fn=analyze_data,
|
167 |
inputs=[csv_file, api_url, api_key, system_prompt],
|
168 |
-
outputs=[status_output, code_output,
|
169 |
)
|
170 |
|
171 |
gr.Markdown("""
|
172 |
## How to Use
|
173 |
-
1. Enter your API URL and key
|
174 |
2. Customize the system prompt if desired
|
175 |
-
3. Upload a CSV file
|
176 |
4. Click 'Analyze Data' to generate and execute analysis code
|
177 |
|
178 |
-
The tool will
|
|
|
|
|
|
|
|
|
179 |
""")
|
180 |
|
181 |
return interface
|
|
|
6 |
import gradio as gr
|
7 |
import tempfile
|
8 |
import sys
|
|
|
9 |
from io import StringIO
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import seaborn as sns
|
12 |
+
import numpy as np
|
13 |
+
from typing import Dict, Any, Tuple, Optional
|
14 |
+
import ast
|
15 |
+
|
16 |
+
# Allow-list of top-level modules that generated code may import (mirrors the smolagents approach)
|
17 |
+
SAFE_IMPORTS = [
|
18 |
+
"pandas", "numpy", "matplotlib", "seaborn", "sklearn",
|
19 |
+
"scipy", "statsmodels", "plotly", "math", "datetime",
|
20 |
+
"collections", "itertools", "functools", "operator"
|
21 |
+
]
|
22 |
+
|
23 |
+
class SafeExecutor:
    """Execute generated Python code after checking its imports against an allow-list.

    NOTE(security): only ``import`` statements are checked. The code is still
    run with ``exec`` and full builtins, so ``__import__``, ``open``, ``eval``
    and friends remain reachable. This is an import filter, not a sandbox; do
    not rely on it to contain untrusted code without real process isolation.
    """

    def __init__(self, allowed_imports=None):
        """Store the allow-list of top-level module names.

        Falls back to the module-level ``SAFE_IMPORTS`` when no list (or an
        empty one) is given.
        """
        self.allowed_imports = allowed_imports or SAFE_IMPORTS

    def validate_imports(self, code: str) -> bool:
        """Check that every import statement in ``code`` targets an allowed module.

        Returns:
            True when all imports are allowed.

        Raises:
            ValueError: if the code cannot be parsed, uses a relative import,
                or imports a module whose top-level package is not allowed.
                Every message starts with "Code validation error: ".
        """
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError) as e:
            raise ValueError(f"Code validation error: {str(e)}") from e

        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                imported = [alias.name for alias in node.names]
            elif isinstance(node, ast.ImportFrom):
                # For "from X import Y" the module is X (node.module); the
                # names list holds the imported symbols, not modules.
                if node.level or not node.module:
                    raise ValueError(
                        "Code validation error: relative imports are not allowed"
                    )
                imported = [node.module]
            else:
                continue

            for dotted_name in imported:
                module = dotted_name.split('.')[0]
                if module not in self.allowed_imports:
                    raise ValueError(
                        f"Code validation error: Import of '{module}' is not allowed. "
                        f"Allowed imports: {self.allowed_imports}"
                    )
        return True

    def execute_code(
        self, code: str, globals_dict: Optional[Dict[str, Any]] = None
    ) -> Tuple[Optional[str], str]:
        """Validate and run ``code``, capturing printed output and any figure.

        Args:
            code: Python source to execute.
            globals_dict: Namespace the code runs in. It is mutated in place:
                the allowed modules are added under their own names and the
                executed code may add or rebind entries.

        Returns:
            ``(image_path, output)``. ``image_path`` is the path of a PNG of
            the current matplotlib figure, or None when no figure is open or
            the run failed. ``output`` is the captured stdout, or a message
            starting with "Error executing code:" on failure.
        """
        if globals_dict is None:
            globals_dict = {}

        # Reject disallowed code before importing anything on its behalf.
        try:
            self.validate_imports(code)
        except ValueError as e:
            return None, f"Error executing code:\n{str(e)}"

        # Best effort: expose each allowed module that is actually installed.
        for module in self.allowed_imports:
            try:
                globals_dict[module] = __import__(module)
            except ImportError:
                continue

        # Redirect stdout to capture print outputs
        old_stdout = sys.stdout
        redirected_output = StringIO()
        sys.stdout = redirected_output

        try:
            exec(code, globals_dict)
            output = redirected_output.getvalue()

            # get_fignums() lists the open figures; pyplot has no get_figs().
            if not plt.get_fignums():
                return None, output

            # Reserve a file name, then close the handle before matplotlib
            # writes to it (an open handle blocks the write on Windows).
            with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
                image_path = tmp.name
            plt.savefig(image_path)
            return image_path, output

        except Exception as e:
            return None, f"Error executing code:\n{str(e)}"
        finally:
            sys.stdout = old_stdout
            # Close figures on every path so a failed run cannot leak plots
            # into the next call.
            plt.close('all')
|
81 |
|
82 |
+
def query_api(prompt: str, api_url: str, api_key: str, system_prompt: str) -> str:
|
83 |
+
"""Send a prompt to the specified API and return the response"""
|
84 |
headers = {
|
85 |
"Content-Type": "application/json",
|
86 |
"Authorization": f"Bearer {api_key}"
|
|
|
90 |
"messages": [
|
91 |
{"role": "system", "content": system_prompt},
|
92 |
{"role": "user", "content": prompt}
|
93 |
+
]
|
|
|
94 |
}
|
95 |
|
96 |
try:
|
|
|
100 |
except requests.exceptions.RequestException as e:
|
101 |
return f"API Error: {str(e)}"
|
102 |
|
103 |
+
def _extract_python_code(response_text: str) -> str:
    """Return runnable code from an API reply, unwrapping a Markdown fence if needed.

    Text that already parses as Python is returned unchanged. Otherwise the
    contents of the first ``` fenced block are returned; if there is no
    fence the text is returned as-is.
    """
    try:
        ast.parse(response_text)
    except (SyntaxError, ValueError):
        fence = "```"
    else:
        return response_text

    start = response_text.find(fence)
    if start == -1:
        return response_text
    # Skip the rest of the opening fence line (e.g. "```python").
    body_start = response_text.find("\n", start)
    if body_start == -1:
        return response_text
    end = response_text.find(fence, body_start)
    if end == -1:
        end = len(response_text)
    return response_text[body_start + 1:end].strip("\n")


def analyze_data(
    csv_file: Any,
    api_url: str,
    api_key: str,
    system_prompt: str
) -> Tuple[str, Optional[str], Optional[str], Optional[str]]:
    """Analyze uploaded CSV data using the API and execute the generated code.

    Args:
        csv_file: The uploaded file, given either as a path string or as an
            object exposing the path via ``.name``.
        api_url: Endpoint passed through to ``query_api``.
        api_key: Credential passed through to ``query_api``.
        system_prompt: System message passed through to ``query_api``.

    Returns:
        ``(status, generated_code, execution_output, visualization_path)``.
        The last three are None when no file was given, the API call failed,
        or an unexpected error occurred.
    """
    if not csv_file:
        return "No file uploaded.", None, None, None

    try:
        # Accept both a plain path and a file wrapper with a ``.name`` path.
        csv_path = getattr(csv_file, "name", csv_file)

        # Read the CSV file
        df = pd.read_csv(csv_path)
        columns = df.columns.tolist()
        sample_data = df.head(3).to_dict()

        # Build the prompt. It states that ``df`` is preloaded because the
        # generated code runs without access to the original upload path.
        prompt = (
            f"Analyze this CSV file with columns: {columns}.\n"
            f"Sample data: {sample_data}\n"
            "\n"
            "The data is already loaded into a pandas DataFrame named `df` "
            "(pd, np, plt and sns are pre-imported); do not read the file again.\n"
            "\n"
            "Generate Python code that:\n"
            "1. Creates insightful visualizations using matplotlib or seaborn\n"
            "2. Performs relevant statistical analysis\n"
            "3. Identifies key patterns or insights\n"
            "4. Properly handles potential data issues\n"
            "\n"
            f"Important: Use only these libraries: {', '.join(SAFE_IMPORTS)}\n"
            "Reply with Python code only."
        )

        # Get code from API
        api_reply = query_api(prompt, api_url, api_key, system_prompt)

        # query_api reports failures as an "API Error: ..." string; surface
        # that as the status instead of trying to execute it.
        if isinstance(api_reply, str) and api_reply.startswith("API Error:"):
            return api_reply, None, None, None

        generated_code = _extract_python_code(api_reply)

        # Create safe executor and execution environment
        executor = SafeExecutor()
        globals_dict = {'df': df, 'pd': pd, 'np': np, 'plt': plt, 'sns': sns}

        # Execute the code
        vis_path, execution_output = executor.execute_code(generated_code, globals_dict)

        # The executor signals failure with no image and an error-prefixed message.
        failed = vis_path is None and execution_output.startswith("Error executing code:")
        if failed:
            status = "Code execution failed; see the execution output for details."
        else:
            status = "Analysis completed successfully."
        return status, generated_code, execution_output, vis_path

    except Exception as e:
        # UI boundary: report any failure as a status message instead of raising.
        return f"Error during analysis: {str(e)}", None, None, None
|
149 |
|
|
|
150 |
def create_interface():
|
151 |
+
"""Create the Gradio interface"""
|
152 |
with gr.Blocks() as interface:
|
153 |
gr.Markdown("# AI-Powered Data Analysis Tool")
|
154 |
|
|
|
156 |
with gr.Column():
|
157 |
api_url = gr.Textbox(
|
158 |
label="API URL",
|
159 |
+
placeholder="Enter API endpoint URL",
|
160 |
type="text"
|
161 |
)
|
162 |
api_key = gr.Textbox(
|
163 |
label="API Key",
|
164 |
+
placeholder="Enter API key",
|
165 |
type="password"
|
166 |
)
|
167 |
system_prompt = gr.Textbox(
|
168 |
label="System Prompt",
|
169 |
placeholder="Enter system prompt for the AI",
|
170 |
+
value="You are an AI assistant specialized in data analysis and visualization.",
|
171 |
lines=3
|
172 |
)
|
173 |
csv_file = gr.File(
|
|
|
182 |
label="Generated Code",
|
183 |
language="python"
|
184 |
)
|
185 |
+
execution_output = gr.Textbox(
|
186 |
+
label="Execution Output",
|
187 |
+
lines=10
|
188 |
+
)
|
189 |
+
visualization_output = gr.Image(
|
190 |
+
label="Visualization",
|
191 |
+
type="filepath"
|
192 |
+
)
|
|
|
193 |
|
194 |
analyze_button.click(
|
195 |
fn=analyze_data,
|
196 |
inputs=[csv_file, api_url, api_key, system_prompt],
|
197 |
+
outputs=[status_output, code_output, execution_output, visualization_output]
|
198 |
)
|
199 |
|
200 |
gr.Markdown("""
|
201 |
## How to Use
|
202 |
+
1. Enter your API URL and key (supports various API providers)
|
203 |
2. Customize the system prompt if desired
|
204 |
+
3. Upload a CSV file for analysis
|
205 |
4. Click 'Analyze Data' to generate and execute analysis code
|
206 |
|
207 |
+
The tool will:
|
208 |
+
- Generate Python code to analyze your data
|
209 |
+
- Execute the code safely in a controlled environment
|
210 |
+
- Display both textual results and visualizations
|
211 |
+
- Support common data science libraries
|
212 |
""")
|
213 |
|
214 |
return interface
|