jzou19950715 committed on
Commit
882008c
·
verified ·
1 Parent(s): 7a8e2c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +168 -193
app.py CHANGED
@@ -1,210 +1,185 @@
1
- # app.py
2
- import streamlit as st
3
- import google.generativeai as generativeai
4
  import os
5
- import re
6
- import json
7
- import logging
8
  import pandas as pd
9
- import plotly.express as px
10
- import plotly.graph_objects as go
11
- import seaborn as sns
 
 
 
12
  import matplotlib.pyplot as plt
13
- import numpy as np
14
  from io import StringIO
15
 
16
def load_data(uploaded_file):
    """Parse an uploaded CSV into a DataFrame.

    Any failure while reading is shown to the user through ``st.error``
    and reported to the caller as ``None``.
    """
    try:
        frame = pd.read_csv(uploaded_file)
    except Exception as exc:
        st.error(f"Error: {str(exc)}")
        return None
    else:
        return frame
23
-
24
def get_numeric_columns(df):
    """Return the labels of every numeric column in *df*.

    Selecting on ``'number'`` instead of the literal ``float64``/``int64``
    pair also picks up other widths (``int32``, ``float32``, ...), which
    CSV post-processing or downcasting can easily produce. Boolean and
    text columns are still excluded.

    Args:
        df: DataFrame to inspect.

    Returns:
        A pandas ``Index`` of the numeric column labels, in column order.
    """
    return df.select_dtypes(include='number').columns
26
-
27
def get_categorical_columns(df):
    """Return the labels of the ``object`` and ``category`` columns of *df*."""
    categorical_frame = df.select_dtypes(include=['object', 'category'])
    return categorical_frame.columns
 
29
 
30
# Logging setup: INFO and above go to a stream handler, each record
# prefixed with its timestamp and level.
_stream_handler = logging.StreamHandler()
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[_stream_handler],
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
def configure_gemini():
    """Build a Gemini model client from the ``GOOGLE_API_KEY`` setting.

    Environment variables are loaded from a ``.env`` file first. Returns the
    ``gemini-1.0-pro`` generative model, or ``None`` (after showing an error
    in the UI) when the key is missing or configuration fails.
    """
    try:
        from dotenv import load_dotenv

        load_dotenv()
        api_key = os.getenv("GOOGLE_API_KEY")
        if api_key:
            generativeai.configure(api_key=api_key)
            return generativeai.GenerativeModel('gemini-1.0-pro')

        st.error("Please set your GOOGLE_API_KEY in the .env file")
        return None
    except Exception as exc:
        st.error(f"Error configuring Gemini: {str(exc)}")
        return None
53
-
54
def get_ai_visualization_suggestion(df, user_query):
    """Ask the Gemini model which visualization fits the data and the query.

    The prompt describes the dataset (column names, dtypes and the first
    rows of every column) and asks for a JSON answer.

    Args:
        df: DataFrame the user uploaded.
        user_query: Free-text description of what the user wants to see.

    Returns:
        The parsed suggestion (a dict with ``viz_type``, ``columns``,
        ``transformations`` and ``parameters`` if the model follows the
        requested structure), or ``None`` when the model is unavailable,
        the request fails, or the reply cannot be parsed.
    """
    model = configure_gemini()
    if not model:
        return None

    # Create a prompt for the AI
    columns_info = {
        'column_names': list(df.columns),
        'data_types': {col: str(df[col].dtype) for col in df.columns},
        'sample_values': {col: df[col].head().tolist() for col in df.columns}
    }

    # default=str is deliberate: sample values such as timestamps are not
    # JSON-serializable, and this text only feeds a prompt, so a readable
    # string form is enough. Without it json.dumps raises here, outside
    # the try block below.
    dataset_json = json.dumps(columns_info, indent=2, default=str)

    prompt = f"""
    Analyze this dataset and the user's query to suggest the best visualization approach:

    User Query: {user_query}

    Dataset Information:
    {dataset_json}

    Please suggest:
    1. The most appropriate type of visualization
    2. Which columns should be used
    3. Any data transformations needed
    4. Visualization parameters (like color schemes, labels, etc.)

    Format your response as JSON with the following structure:
    {{
    "viz_type": "type of visualization",
    "columns": ["column1", "column2"],
    "transformations": ["transformation1", "transformation2"],
    "parameters": {{
    "param1": "value1",
    "param2": "value2"
    }}
    }}
    """

    try:
        response = model.generate_content(prompt)
        reply = response.text

        # Extract JSON from response: models frequently wrap the object in
        # a Markdown fence or add prose around it, so parse only the span
        # from the first "{" to the last "}".
        start = reply.find("{")
        end = reply.rfind("}")
        if start != -1 and end > start:
            reply = reply[start:end + 1]

        return json.loads(reply)
    except Exception as e:
        logger.error(f"Error getting AI suggestion: {str(e)}")
        return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
def main():
    """Render the Streamlit dashboard.

    Flow: CSV upload, data overview (shape, preview, missing values), an
    AI suggestion request, a manually configured chart, and summary
    statistics. Nothing below the uploader is rendered until a file has
    been uploaded and parsed successfully.
    """
    st.title("📊 AI-Powered Data Visualization Dashboard")
    st.write("Upload your CSV file and explore the data through various visualizations!")

    uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
    if uploaded_file is None:
        return

    df = load_data(uploaded_file)
    if df is None:
        return

    st.success("File successfully loaded!")

    # Basic Data Info
    st.header("📝 Data Overview")
    st.write(f"Number of rows: {df.shape[0]}")
    st.write(f"Number of columns: {df.shape[1]}")

    # Data Preview
    st.subheader("Data Preview")
    st.dataframe(df.head())

    # Missing Values Analysis
    st.subheader("Missing Values Analysis")
    missing_data = df.isnull().sum()
    if missing_data.sum() > 0:
        st.write("Missing values by column:")
        st.write(missing_data[missing_data > 0])
    else:
        st.write("No missing values found in the dataset!")

    # User Query for AI Suggestions
    st.header("🤖 AI-Powered Visualization")
    user_query = st.text_input("Describe what you want to visualize",
                               "Show me trends in the data")

    if st.button("Get AI Suggestion"):
        # NOTE(review): the original statement was cut off in the middle of
        # the spinner message (unterminated string literal), so the module
        # could not be parsed. The message ending and the body below are a
        # minimal reconstruction - confirm the intended behavior.
        with st.spinner("Getting AI visualization suggestion..."):
            suggestion = get_ai_visualization_suggestion(df, user_query)
        if suggestion:
            st.json(suggestion)
        else:
            st.warning("Could not get an AI suggestion.")

    viz_type = st.selectbox(
        "Choose visualization type",
        ["Scatter Plot", "Line Plot", "Bar Plot", "Histogram", "Box Plot", "Correlation Heatmap"]
    )

    numeric_columns = get_numeric_columns(df)
    categorical_columns = get_categorical_columns(df)

    # Each chart type is offered only when the data has enough numeric
    # columns for it; otherwise the branch is skipped silently.
    if viz_type == "Scatter Plot" and len(numeric_columns) >= 2:
        x_col = st.selectbox("Select X axis", numeric_columns)
        y_col = st.selectbox("Select Y axis", numeric_columns)
        color_col = st.selectbox("Select Color variable (optional)",
                                 ["None"] + list(df.columns))

        if color_col == "None":
            fig = px.scatter(df, x=x_col, y=y_col)
        else:
            fig = px.scatter(df, x=x_col, y=y_col, color=color_col)
        st.plotly_chart(fig)

    elif viz_type == "Line Plot" and len(numeric_columns) >= 1:
        x_col = st.selectbox("Select X axis", df.columns)
        y_col = st.selectbox("Select Y axis", numeric_columns)
        fig = px.line(df, x=x_col, y=y_col)
        st.plotly_chart(fig)

    elif viz_type == "Bar Plot":
        x_col = st.selectbox("Select X axis", df.columns)
        y_col = st.selectbox("Select Y axis", numeric_columns)
        fig = px.bar(df, x=x_col, y=y_col)
        st.plotly_chart(fig)

    elif viz_type == "Histogram" and len(numeric_columns) >= 1:
        col = st.selectbox("Select column", numeric_columns)
        bins = st.slider("Number of bins", min_value=5, max_value=100, value=30)
        fig = px.histogram(df, x=col, nbins=bins)
        st.plotly_chart(fig)

    elif viz_type == "Box Plot" and len(numeric_columns) >= 1:
        y_col = st.selectbox("Select column for box plot", numeric_columns)
        x_col = st.selectbox("Select grouping variable (optional)",
                             ["None"] + list(categorical_columns))

        if x_col == "None":
            fig = px.box(df, y=y_col)
        else:
            fig = px.box(df, x=x_col, y=y_col)
        st.plotly_chart(fig)

    elif viz_type == "Correlation Heatmap" and len(numeric_columns) >= 2:
        corr_matrix = df[numeric_columns].corr()
        fig = px.imshow(corr_matrix,
                        labels=dict(color="Correlation"),
                        x=corr_matrix.columns,
                        y=corr_matrix.columns)
        st.plotly_chart(fig)

    # Data Summary
    st.header("📊 Data Summary")
    if len(numeric_columns) > 0:
        st.subheader("Numerical Columns Summary")
        st.write(df[numeric_columns].describe())

    if len(categorical_columns) > 0:
        st.subheader("Categorical Columns Summary")
        for col in categorical_columns:
            st.write(f"\nValue counts for {col}:")
            st.write(df[col].value_counts())


if __name__ == "__main__":
    main()
 
 
 
 
 
1
  import os
 
 
 
2
  import pandas as pd
3
+ import requests
4
+ import json
5
+ import subprocess
6
+ import gradio as gr
7
+ import tempfile
8
+ import sys
9
  import matplotlib.pyplot as plt
 
10
  from io import StringIO
11
 
12
def query_api(prompt, api_url, api_key, system_prompt, timeout=60):
    """Send a chat-style prompt to an HTTP API and return the reply text.

    The JSON request body carries one system and one user message with
    streaming disabled. The reply is read from
    ``choices[0]["message"]["content"]`` of the JSON response.

    NOTE(review): no ``"model"`` field is sent; some chat-completion
    endpoints require one - confirm against the target API.

    Args:
        prompt: Text sent as the user message.
        api_url: Endpoint the request is POSTed to.
        api_key: Token sent in a ``Bearer`` Authorization header.
        system_prompt: Text sent as the system message.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The reply text on success; otherwise a string starting with
        ``"API Error: "`` that describes the failure.
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    payload = {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ],
        "stream": False
    }

    try:
        # Without an explicit timeout, requests waits indefinitely on a
        # server that accepts the connection but never answers.
        response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
    except requests.exceptions.RequestException as e:
        return f"API Error: {str(e)}"
    except (KeyError, IndexError, TypeError, ValueError) as e:
        # A 2xx reply whose body is not the expected chat-completion JSON.
        return f"API Error: unexpected response format ({e!r})"
33
+
34
def install_package(package):
    """Run ``pip install`` for *package* using the current interpreter.

    Returns ``True`` when pip exits successfully and ``False`` when it exits
    with a non-zero status.
    """
    pip_command = [sys.executable, "-m", "pip", "install", package]
    try:
        subprocess.check_call(pip_command)
    except subprocess.CalledProcessError:
        return False
    else:
        return True
41
+
42
def safe_execute_code(code, globals_dict=None):
    """Execute Python source text and capture what it prints.

    SECURITY NOTE: despite the name, nothing here is sandboxed. ``exec``
    runs the code with full builtins and the privileges of this process,
    so it can read files, open sockets or import anything. Only pass
    source you are willing to run as-is.

    Args:
        code: Python source text to run.
        globals_dict: Namespace the code runs in. A fresh dict is used when
            omitted. The executed code may add to or modify this dict.

    Returns:
        ``(True, captured_stdout)`` when the code finishes, or
        ``(False, "Error executing code: <message>")`` when it raises.
    """
    from contextlib import redirect_stdout

    if globals_dict is None:
        globals_dict = {}

    captured = StringIO()
    try:
        # redirect_stdout puts the previous sys.stdout back on exit,
        # whether or not the executed code raises.
        with redirect_stdout(captured):
            exec(code, globals_dict)
    except SystemExit as e:
        # Generated scripts may call exit()/sys.exit(); that must be
        # reported as a failure rather than terminate the host process.
        return False, f"Error executing code: SystemExit({e.code!r})"
    except Exception as e:
        return False, f"Error executing code: {str(e)}"

    return True, captured.getvalue()
61
+
62
def _extract_code(response_text):
    """Return the Python source inside the first Markdown code fence.

    Chat models commonly wrap code in a fenced block; handing the fence
    markers to ``exec`` is a SyntaxError. Text without a fence is returned
    unchanged.
    """
    import re

    fenced = re.search(
        r"```(?:python|py)?[ \t]*\r?\n(.*?)```",
        response_text,
        re.DOTALL | re.IGNORECASE,
    )
    return fenced.group(1) if fenced else response_text


def analyze_data(csv_file, api_url, api_key, system_prompt):
    """Generate analysis code for an uploaded CSV via the API and run it.

    Args:
        csv_file: The upload: either a filesystem path or an object whose
            ``name`` attribute is the path.
        api_url: Endpoint passed through to ``query_api``.
        api_key: Key passed through to ``query_api``.
        system_prompt: System prompt passed through to ``query_api``.

    Returns:
        A 3-tuple ``(status, generated_code, details)``:

        * success: ``details`` is ``(stdout_text, image_path_or_None)``;
        * generated code raised: ``details`` is the error text;
        * API failure: ``generated_code`` is ``None`` and ``details`` is the
          ``"API Error: ..."`` text;
        * no upload or unexpected error: both are ``None``.
    """
    import shutil

    if not csv_file:
        return "No file uploaded.", None, None

    try:
        # Accept both a plain path and a file wrapper exposing ``.name``.
        csv_path = getattr(csv_file, "name", csv_file)
        df = pd.read_csv(csv_path)
        columns = df.columns.tolist()
        sample_data = df.head(3).to_dict()

        # Build the prompt. Each numbered item ends with a newline:
        # adjacent string literals are concatenated with no separator, so
        # without them the list items run together.
        prompt = (
            f"I have a CSV file with columns: {columns}. "
            f"The first few rows are: {sample_data}. "
            "Please generate Python code to analyze this data. Include:\n"
            "1. Basic statistical analysis\n"
            "2. Data visualization using matplotlib or seaborn\n"
            "3. Any interesting patterns or insights\n"
            "Make sure to use only standard data science libraries."
        )

        # Get code from API
        reply = query_api(prompt, api_url, api_key, system_prompt)
        if isinstance(reply, str) and reply.startswith("API Error:"):
            # query_api reports failures as text; executing that text
            # would only yield a misleading SyntaxError.
            return "API request failed.", None, reply
        generated_code = _extract_code(reply)

        original_cwd = os.getcwd()
        with tempfile.TemporaryDirectory() as temp_dir:
            # The generated code may read or write relative paths, so it
            # runs with the scratch directory as cwd. chdir is
            # process-wide, hence the unconditional restore below.
            os.chdir(temp_dir)
            try:
                # Save the DataFrame in the temporary directory
                df.to_csv("input_data.csv", index=False)

                # Prepare the execution environment
                globals_dict = {
                    'pd': pd,
                    'plt': plt,
                    'df': df,
                    '__file__': 'input_data.csv'
                }

                # Execute the code
                success, execution_output = safe_execute_code(generated_code, globals_dict)
                if not success:
                    return "Code execution failed.", generated_code, execution_output

                # Save any generated plots. The image is written outside
                # the scratch directory, under an absolute path, because
                # the scratch directory is deleted before the caller can
                # read the file.
                has_open_figure = bool(plt.get_fignums())
                wrote_image = os.path.exists("visualization.png")
                image_path = None
                if has_open_figure or wrote_image:
                    handle, image_path = tempfile.mkstemp(suffix=".png")
                    os.close(handle)
                    if has_open_figure:
                        plt.savefig(image_path)
                    else:
                        shutil.copyfile("visualization.png", image_path)

                return (
                    "Analysis completed successfully.",
                    generated_code,
                    (execution_output, image_path),
                )
            finally:
                # Leave the scratch directory before it is removed and
                # drop figures so they do not leak into the next request.
                os.chdir(original_cwd)
                plt.close('all')

    except Exception as e:
        return f"Error during analysis: {str(e)}", None, None
119
+
120
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI for the analysis tool.

    Left column: API URL, API key, system prompt, CSV upload and the
    "Analyze Data" button. Right column: status, generated code, captured
    text output and the visualization image.

    Returns:
        The ``gr.Blocks`` instance, not yet launched.
    """

    def _run_analysis(csv_file, api_url, api_key, system_prompt):
        """Adapt ``analyze_data``'s 3-tuple to the four output components.

        Event ``outputs`` must be a flat list of components, with one
        returned value per component. ``analyze_data`` packs text and
        image into its third element on success and returns a plain
        string or ``None`` there otherwise.
        """
        status, code, details = analyze_data(csv_file, api_url, api_key, system_prompt)
        if isinstance(details, (tuple, list)) and len(details) == 2:
            text, image = details
        else:
            text, image = details, None
        return status, code, text, image

    with gr.Blocks() as interface:
        gr.Markdown("# AI-Powered Data Analysis Tool")

        with gr.Row():
            with gr.Column():
                api_url = gr.Textbox(
                    label="API URL",
                    placeholder="Enter your API endpoint URL",
                    type="text"
                )
                api_key = gr.Textbox(
                    label="API Key",
                    placeholder="Enter your API key",
                    type="password"
                )
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    placeholder="Enter system prompt for the AI",
                    value="You are an AI assistant specialized in data analysis, visualization, and Python programming.",
                    lines=3
                )
                csv_file = gr.File(
                    label="Upload CSV File",
                    file_types=[".csv"]
                )
                analyze_button = gr.Button("Analyze Data")

            with gr.Column():
                status_output = gr.Textbox(label="Status")
                code_output = gr.Code(
                    label="Generated Code",
                    language="python"
                )
                with gr.Row():
                    text_output = gr.Textbox(
                        label="Analysis Output",
                        lines=10
                    )
                    image_output = gr.Image(
                        label="Visualization",
                        type="filepath"
                    )

        analyze_button.click(
            fn=_run_analysis,
            inputs=[csv_file, api_url, api_key, system_prompt],
            outputs=[status_output, code_output, text_output, image_output]
        )

        gr.Markdown("""
        ## How to Use
        1. Enter your API URL and key for the AI service you want to use (e.g., OpenAI, DeepSeek)
        2. Customize the system prompt if desired
        3. Upload a CSV file
        4. Click 'Analyze Data' to generate and execute analysis code

        The tool will generate Python code to analyze your data and create visualizations.
        """)

    return interface


if __name__ == "__main__":
    interface = create_interface()
    interface.launch()