jzou19950715 committed on
Commit
e7486fb
·
verified ·
1 Parent(s): e826aa2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +188 -267
app.py CHANGED
@@ -1,289 +1,210 @@
 
 
 
1
  import os
2
- import shutil
3
- import gradio as gr
 
4
  import pandas as pd
5
- import numpy as np
6
  import plotly.express as px
7
- from typing import Dict, List, Optional, Tuple
8
- from dataclasses import dataclass
9
- from openai import OpenAI # OpenAI API client
 
 
10
 
11
# Default instructions sent to the model unless the caller overrides them.
_DEFAULT_SYSTEM_PROMPT = """
    You are an expert data scientist. Analyze the data and provide insights.
    Your responses should be clear, concise, and actionable. Always provide explanations
    for your analysis and include visualizations when appropriate.
    """


@dataclass
class AgentConfig:
    """Tunable settings for the data science agent.

    Every field has a default, so ``AgentConfig()`` yields a usable
    configuration without arguments.
    """

    # Instructions placed in the "system" role of the chat request.
    system_prompt: str = _DEFAULT_SYSTEM_PROMPT
    # Upper bound on agent iterations.
    max_iterations: int = 10
    # Sampling temperature for the model.
    temperature: float = 0.7
    # Use GPT-4 or another valid model.
    model_name: str = "gpt-4"
23
 
24
@dataclass
class AnalysisState:
    """Mutable container for the dataset and results of the ongoing analysis.

    All fields default to ``None`` except ``error_log``, which is normalised
    to an empty list on construction so a fresh instance and a cleared
    instance look the same.
    """

    # The currently loaded dataset, or None when nothing is loaded.
    df: Optional[pd.DataFrame] = None
    # Summary information about the loaded dataset.
    current_analysis: Optional[Dict] = None
    # Visualizations produced so far.
    visualizations: Optional[List[Dict]] = None
    # Accumulated error messages; normalised to a list in __post_init__.
    error_log: Optional[List[str]] = None

    def __post_init__(self) -> None:
        # A per-instance list avoids both a shared mutable default and the
        # None-vs-[] mismatch between a fresh and a cleared state.
        if self.error_log is None:
            self.error_log = []

    def clear(self) -> None:
        """Drop the dataset and all derived results, and empty the error log."""
        self.df = None
        self.current_analysis = None
        self.visualizations = None
        self.error_log = []

    def log_error(self, error: str) -> None:
        """Append ``error`` to the error log.

        Args:
            error: Human-readable description of what went wrong.
        """
        # Callers may still have assigned None explicitly; tolerate it.
        if self.error_log is None:
            self.error_log = []
        self.error_log.append(error)
 
 
 
43
 
44
- # Helper functions for data processing
45
def process_uploaded_file(file) -> Tuple[Optional[pd.DataFrame], Dict]:
    """Load an uploaded CSV, XLSX or JSON file into a DataFrame.

    Args:
        file: Object whose ``name`` attribute is the path of the uploaded
            file on disk.

    Returns:
        ``(df, info)`` on success, where ``info`` describes the shape,
        columns, dtypes, missing-value counts and numeric/categorical column
        split. ``(None, {"error": message})`` when the format is unsupported
        or reading fails.
    """
    try:
        path = file.name
        # Compare the extension case-insensitively so "DATA.CSV" is accepted.
        suffix = os.path.splitext(path)[1].lower()

        if suffix == ".csv":
            df = pd.read_csv(path)
        elif suffix == ".xlsx":
            df = pd.read_excel(path)
        elif suffix == ".json":
            df = pd.read_json(path)
        else:
            return None, {"error": "Unsupported file format"}

        info = {
            "shape": df.shape,
            "columns": list(df.columns),
            "dtypes": df.dtypes.to_dict(),
            "missing_values": df.isnull().sum().to_dict(),
            "numeric_columns": list(df.select_dtypes(include=[np.number]).columns),
            "categorical_columns": list(df.select_dtypes(exclude=[np.number]).columns),
        }

        return df, info
    except Exception as e:
        # Any read or parse failure is reported to the caller, not raised.
        return None, {"error": str(e)}
69
-
70
def create_visualization(data: pd.DataFrame, viz_type: str, params: Dict) -> Optional[str]:
    """Render a Plotly figure as an embeddable HTML fragment.

    Args:
        data: DataFrame to plot.
        viz_type: One of ``"scatter"``, ``"histogram"`` or ``"heatmap"``.
        params: Plot options. ``"x"`` is required for scatter and histogram,
            ``"y"`` for scatter; ``"color"``, ``"nbins"`` and ``"title"`` are
            optional.

    Returns:
        The figure's HTML (without the surrounding ``<html>`` document), or
        ``None`` when ``viz_type`` is unknown or the figure cannot be built.
    """
    import logging

    try:
        if viz_type == "scatter":
            fig = px.scatter(
                data,
                x=params["x"],
                y=params["y"],
                color=params.get("color"),
                title=params.get("title", "Scatter Plot"),
            )
        elif viz_type == "histogram":
            fig = px.histogram(
                data,
                x=params["x"],
                nbins=params.get("nbins", 30),
                title=params.get("title", "Distribution"),
            )
        elif viz_type == "heatmap":
            # Correlation is only defined over the numeric columns.
            numeric_cols = data.select_dtypes(include=[np.number]).columns
            corr = data[numeric_cols].corr()
            fig = px.imshow(
                corr,
                labels=dict(color="Correlation"),
                title=params.get("title", "Correlation Heatmap"),
            )
        else:
            return None

        # Convert Plotly figure to HTML
        return fig.to_html(full_html=False)
    except Exception:
        # Honour the Optional[str] contract: callers test for None, so an
        # error dict would be mistaken for HTML. Keep the details in the log.
        logging.getLogger(__name__).exception(
            "Failed to create %r visualization", viz_type
        )
        return None
 
103
 
104
def load_example_data(dataset_name: str = "iris") -> Tuple[pd.DataFrame, Dict]:
    """Load a bundled example dataset ("iris" or "diabetes").

    Returns ``(df, info)`` where ``df`` holds the feature columns plus a
    ``target`` column and ``info`` summarises it, or
    ``(None, {"error": message})`` for an unknown name or any failure.
    """
    if dataset_name not in ("iris", "diabetes"):
        return None, {"error": "Invalid dataset name"}

    try:
        # NOTE(review): load_iris / load_diabetes are not imported anywhere in
        # the visible file, so this lookup raises NameError and the function
        # reports it as an error. They look like sklearn.datasets loaders --
        # confirm and add the import at file level.
        bunch = load_iris() if dataset_name == "iris" else load_diabetes()

        frame = pd.DataFrame(bunch.data, columns=bunch.feature_names)
        frame['target'] = bunch.target

        numeric = frame.select_dtypes(include=[np.number])
        non_numeric = frame.select_dtypes(exclude=[np.number])
        summary = {
            "shape": frame.shape,
            "columns": list(frame.columns),
            "dtypes": frame.dtypes.to_dict(),
            "missing_values": frame.isnull().sum().to_dict(),
            "numeric_columns": list(numeric.columns),
            "categorical_columns": list(non_numeric.columns),
        }
    except Exception as exc:
        return None, {"error": str(exc)}

    return frame, summary
130
 
131
def query_openai(
    api_key: str,
    system_prompt: str,
    user_prompt: str,
    *,
    model: str = "gpt-4",
    temperature: float = 0.7,
    max_tokens: int = 500,
) -> str:
    """Send one system + user prompt pair to the OpenAI chat completions API.

    Args:
        api_key: OpenAI API key used to build the client.
        system_prompt: Content of the "system" message.
        user_prompt: Content of the "user" message.
        model: Chat model name; defaults to the previously hard-coded value.
        temperature: Sampling temperature.
        max_tokens: Upper bound on the length of the reply.

    Returns:
        The text of the first choice, or a string starting with
        ``"Error querying OpenAI API:"`` when the request fails. Failures are
        returned rather than raised so the chat UI can display them.
    """
    try:
        client = OpenAI(api_key=api_key)
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=temperature,
            max_tokens=max_tokens,
        )
        # The message content may be None; keep the declared str return type.
        return response.choices[0].message.content or ""
    except Exception as e:
        return f"Error querying OpenAI API: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
def create_demo():
    """Build the Gradio Blocks UI for the data science agent.

    The left column holds the API key, data upload, example-data button,
    visualization settings and system prompt. The right column holds the
    chat, the rendered plot, the dataset summary and an error box.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    # Initialize configuration and state
    config = AgentConfig()
    # NOTE(review): this object is captured by the handlers below, so it looks
    # like it is shared by every session of the process -- confirm intended.
    analysis_state = AnalysisState()

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🔬 Advanced Data Science Agent")

        with gr.Row():
            with gr.Column(scale=1):
                api_key = gr.Textbox(
                    label="OpenAI API Key",
                    type="password",
                    placeholder="Enter your OpenAI API key",
                )

                file_input = gr.File(
                    label="Upload Data",
                    file_types=[".csv", ".xlsx", ".json"],
                )

                example_btn = gr.Button("Load Example Dataset")

                with gr.Accordion("Visualization Settings", open=False):
                    viz_type = gr.Dropdown(
                        choices=["scatter", "histogram", "heatmap"],
                        label="Visualization Type",
                        value="scatter",
                    )
                    x_axis = gr.Dropdown(label="X-axis", interactive=True)
                    y_axis = gr.Dropdown(label="Y-axis", interactive=True)
                    color_column = gr.Dropdown(label="Color Column", interactive=True)

                with gr.Accordion("System Prompt", open=False):
                    system_prompt = gr.Textbox(
                        label="System Prompt",
                        value=config.system_prompt,
                        lines=5,
                    )

            with gr.Column(scale=2):
                chatbot = gr.Chatbot(label="Analysis Chat", height=300)
                with gr.Row():
                    chat_input = gr.Textbox(
                        label="Ask about your data",
                        placeholder="Type your question here...",
                        lines=2,
                    )
                    submit_btn = gr.Button("Send", variant="primary")

                plot_output = gr.HTML(label="Generated Plots")
                results_json = gr.JSON(label="Analysis Results")
                error_output = gr.Textbox(label="Error Log", visible=False)

        # --- Helpers shared by the event handlers ---------------------------

        def _error_update(message=None):
            # The error box starts hidden; reveal it only when there is text.
            return gr.update(value=message or "", visible=bool(message))

        def _dataset_loaded(df, info):
            # Remember the dataset and offer its columns in all three
            # dropdowns. Returning a bare list would set the dropdown *value*,
            # not its choices, so an explicit update is required.
            analysis_state.df = df
            analysis_state.current_analysis = info
            columns = list(df.columns)
            return (
                info,
                gr.update(choices=columns, value=None),
                gr.update(choices=columns, value=None),
                gr.update(choices=columns, value=None),
                _error_update(),
            )

        def _load_failed(message, info=None):
            # Leave the dropdowns untouched: a previously loaded dataset is
            # still held in analysis_state.
            detail = (info or {}).get("error")
            if detail:
                message = f"{message}: {detail}"
            return None, gr.update(), gr.update(), gr.update(), _error_update(message)

        # --- Event handlers -------------------------------------------------

        def handle_file_upload(file):
            if file is None:
                return _load_failed("No file uploaded")
            df, info = process_uploaded_file(file)
            if df is not None:
                return _dataset_loaded(df, info)
            return _load_failed("Failed to load file", info)

        def handle_example_data():
            df, info = load_example_data("iris")
            if df is not None:
                return _dataset_loaded(df, info)
            return _load_failed("Failed to load example data", info)

        def handle_visualization(viz_type, x_axis, y_axis, color_column):
            if analysis_state.df is None:
                return None, _error_update("No data available")

            # Dropdown changes fire while the user is still choosing columns;
            # wait quietly until the selected plot has what it needs.
            required = {
                "scatter": (x_axis, y_axis),
                "histogram": (x_axis,),
            }.get(viz_type, ())
            if not all(required):
                return gr.update(), _error_update()

            params = {"x": x_axis, "y": y_axis, "color": color_column or None}
            result = create_visualization(analysis_state.df, viz_type, params)
            if isinstance(result, str):
                return result, _error_update()

            # Anything other than an HTML string (None, or an error mapping)
            # is a failure and must not be rendered as a plot.
            detail = result.get("error") if isinstance(result, dict) else None
            message = "Failed to create visualization"
            if detail:
                message = f"{message}: {detail}"
            return None, _error_update(message)

        def handle_chat_message(api_key, system_prompt, message, chat_history):
            # The chatbot value can be None before the first message.
            history = chat_history or []
            if analysis_state.df is None:
                return history + [(message, "Please upload a data file first.")], ""
            if not api_key:
                return history + [(message, "Please enter your OpenAI API key.")], ""

            # Query OpenAI API
            # NOTE(review): only the user's message is sent; the loaded data
            # is not included in the prompt -- confirm whether that is intended.
            response = query_openai(api_key, system_prompt, message)
            return history + [(message, response)], ""

        # --- Wiring ---------------------------------------------------------

        load_outputs = [results_json, x_axis, y_axis, color_column, error_output]
        file_input.change(
            handle_file_upload,
            inputs=[file_input],
            outputs=load_outputs,
        )
        example_btn.click(
            handle_example_data,
            outputs=load_outputs,
        )

        # Any change to a plot setting re-renders the figure.
        viz_inputs = [viz_type, x_axis, y_axis, color_column]
        for control in viz_inputs:
            control.change(
                handle_visualization,
                inputs=viz_inputs,
                outputs=[plot_output, error_output],
            )

        submit_btn.click(
            handle_chat_message,
            inputs=[api_key, system_prompt, chat_input, chatbot],
            outputs=[chatbot, chat_input],
        )

    return demo
283
 
284
# Build the UI once; only the launch options depend on how the module was
# loaded: a public share link when run as a script, a hidden API page when
# imported.
demo = create_demo()
if __name__ == "__main__":
    demo.launch(share=True)
else:
    demo.launch(show_api=False)
 
1
+ # app.py
2
+ import streamlit as st
3
+ import google.generativeai as generativeai
4
  import os
5
+ import re
6
+ import json
7
+ import logging
8
  import pandas as pd
 
9
  import plotly.express as px
10
+ import plotly.graph_objects as go
11
+ import seaborn as sns
12
+ import matplotlib.pyplot as plt
13
+ import numpy as np
14
+ from io import StringIO
15
 
16
def load_data(uploaded_file):
    """Parse an uploaded CSV into a DataFrame.

    On any failure the error is shown in the Streamlit page and ``None`` is
    returned instead of raising.
    """
    try:
        frame = pd.read_csv(uploaded_file)
    except Exception as exc:
        st.error(f"Error: {str(exc)}")
        return None
    else:
        return frame
 
 
 
 
 
23
 
24
def get_numeric_columns(df):
    """Return the labels of all numeric columns in ``df``.

    Selecting on the generic ``"number"`` dtype covers every integer and
    float width (int32, float32, unsigned, ...), not only ``int64`` and
    ``float64``. Boolean columns are not treated as numeric.

    Args:
        df: DataFrame to inspect.

    Returns:
        A pandas ``Index`` of column labels, in the DataFrame's column order.
    """
    return df.select_dtypes(include='number').columns
 
 
 
 
 
 
26
 
27
def get_categorical_columns(df):
    """Return the labels of all text-like / categorical columns in ``df``.

    Besides ``object`` and ``category`` columns this also selects columns
    using the pandas ``string`` dtype, which an ``object``-only filter can
    miss.

    Args:
        df: DataFrame to inspect.

    Returns:
        A pandas ``Index`` of column labels, in the DataFrame's column order.
    """
    return df.select_dtypes(include=['object', 'category', 'string']).columns
 
 
 
29
 
30
+ # Configure logging
31
+ logging.basicConfig(
32
+ level=logging.INFO,
33
+ format='%(asctime)s - %(levelname)s - %(message)s',
34
+ handlers=[logging.StreamHandler()]
35
+ )
36
+ logger = logging.getLogger(__name__)
37
 
38
def configure_gemini(model_name='gemini-1.0-pro'):
    """Configure Google's Gemini AI model.

    Reads ``GOOGLE_API_KEY`` from the environment, first loading a ``.env``
    file when ``python-dotenv`` is installed. Errors are shown in the
    Streamlit page rather than raised.

    Args:
        model_name: Gemini model to instantiate. Defaults to the value that
            was previously hard-coded.

    Returns:
        A ``GenerativeModel`` instance, or ``None`` when the key is missing
        or configuration fails.
    """
    try:
        # python-dotenv is a convenience only: when the key is already in the
        # process environment, a missing package must not abort configuration.
        try:
            from dotenv import load_dotenv
        except ImportError:
            logger.info("python-dotenv not installed; using the process environment only")
        else:
            load_dotenv()

        api_key = os.getenv("GOOGLE_API_KEY")
        if not api_key:
            st.error("Please set your GOOGLE_API_KEY in the .env file")
            return None

        generativeai.configure(api_key=api_key)
        return generativeai.GenerativeModel(model_name)
    except Exception as e:
        st.error(f"Error configuring Gemini: {str(e)}")
        return None
53
+
54
def _extract_json_payload(text):
    """Parse the JSON object embedded in a model reply.

    Accepts bare JSON, JSON inside a Markdown code fence, and JSON surrounded
    by prose. Raises ``ValueError`` when no object can be found or parsed
    (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    fenced = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL)
    candidate = fenced.group(1) if fenced else text
    # Trim any commentary around the outermost braces.
    start = candidate.find("{")
    end = candidate.rfind("}")
    if start == -1 or end < start:
        raise ValueError("no JSON object found in model response")
    return json.loads(candidate[start:end + 1])


def get_ai_visualization_suggestion(df, user_query):
    """Get AI-powered visualization suggestions based on the data and user query.

    Args:
        df: DataFrame whose column names, dtypes and first rows are described
            to the model.
        user_query: Free-text description of what the user wants to see.

    Returns:
        The parsed JSON suggestion (the prompt asks for the keys ``viz_type``,
        ``columns``, ``transformations`` and ``parameters``), or ``None`` when
        the model is unavailable or its reply cannot be used.
    """
    model = configure_gemini()
    if not model:
        return None

    try:
        # Create a prompt for the AI
        columns_info = {
            'column_names': list(df.columns),
            'data_types': {col: str(df[col].dtype) for col in df.columns},
            'sample_values': {col: df[col].head().tolist() for col in df.columns}
        }

        # default=str is deliberate: sample values such as timestamps are not
        # JSON-serialisable, and the prompt only needs a readable rendering.
        prompt = f"""
    Analyze this dataset and the user's query to suggest the best visualization approach:

    User Query: {user_query}

    Dataset Information:
    {json.dumps(columns_info, indent=2, default=str)}

    Please suggest:
    1. The most appropriate type of visualization
    2. Which columns should be used
    3. Any data transformations needed
    4. Visualization parameters (like color schemes, labels, etc.)

    Format your response as JSON with the following structure:
    {{
    "viz_type": "type of visualization",
    "columns": ["column1", "column2"],
    "transformations": ["transformation1", "transformation2"],
    "parameters": {{
    "param1": "value1",
    "param2": "value2"
    }}
    }}
    """

        response = model.generate_content(prompt)
        # Extract JSON from response (it may arrive fenced or with prose).
        return _extract_json_payload(response.text)
    except Exception as e:
        logger.error(f"Error getting AI suggestion: {str(e)}")
        return None
101
 
102
+ def main():
103
+ st.title("📊 AI-Powered Data Visualization Dashboard")
104
+ st.write("Upload your CSV file and explore the data through various visualizations!")
 
 
 
 
 
 
 
 
 
 
105
 
106
+ uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
 
 
 
 
 
 
 
 
 
 
 
107
 
108
+ if uploaded_file is not None:
109
+ df = load_data(uploaded_file)
110
+
111
+ if df is not None:
112
+ st.success("File successfully loaded!")
113
+
114
+ # Basic Data Info
115
+ st.header("📝 Data Overview")
116
+ st.write(f"Number of rows: {df.shape[0]}")
117
+ st.write(f"Number of columns: {df.shape[1]}")
118
+
119
+ # Data Preview
120
+ st.subheader("Data Preview")
121
+ st.dataframe(df.head())
122
+
123
+ # Missing Values Analysis
124
+ st.subheader("Missing Values Analysis")
125
+ missing_data = df.isnull().sum()
126
+ if missing_data.sum() > 0:
127
+ st.write("Missing values by column:")
128
+ st.write(missing_data[missing_data > 0])
129
+ else:
130
+ st.write("No missing values found in the dataset!")
131
+
132
+ # User Query for AI Suggestions
133
+ st.header("🤖 AI-Powered Visualization")
134
+ user_query = st.text_input("Describe what you want to visualize",
135
+ "Show me trends in the data")
136
+
137
+ if st.button("Get AI Suggestion"):
138
+ with st.spinner("Getting AI visualization
139
+
140
+ viz_type = st.selectbox(
141
+ "Choose visualization type",
142
+ ["Scatter Plot", "Line Plot", "Bar Plot", "Histogram", "Box Plot", "Correlation Heatmap"]
143
+ )
144
 
145
+ numeric_columns = get_numeric_columns(df)
146
+ categorical_columns = get_categorical_columns(df)
 
 
147
 
148
+ if viz_type == "Scatter Plot" and len(numeric_columns) >= 2:
149
+ x_col = st.selectbox("Select X axis", numeric_columns)
150
+ y_col = st.selectbox("Select Y axis", numeric_columns)
151
+ color_col = st.selectbox("Select Color variable (optional)",
152
+ ["None"] + list(df.columns))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
+ if color_col == "None":
155
+ fig = px.scatter(df, x=x_col, y=y_col)
156
+ else:
157
+ fig = px.scatter(df, x=x_col, y=y_col, color=color_col)
158
+ st.plotly_chart(fig)
159
+
160
+ elif viz_type == "Line Plot" and len(numeric_columns) >= 1:
161
+ x_col = st.selectbox("Select X axis", df.columns)
162
+ y_col = st.selectbox("Select Y axis", numeric_columns)
163
+ fig = px.line(df, x=x_col, y=y_col)
164
+ st.plotly_chart(fig)
165
+
166
+ elif viz_type == "Bar Plot":
167
+ x_col = st.selectbox("Select X axis", df.columns)
168
+ y_col = st.selectbox("Select Y axis", numeric_columns)
169
+ fig = px.bar(df, x=x_col, y=y_col)
170
+ st.plotly_chart(fig)
171
+
172
+ elif viz_type == "Histogram" and len(numeric_columns) >= 1:
173
+ col = st.selectbox("Select column", numeric_columns)
174
+ bins = st.slider("Number of bins", min_value=5, max_value=100, value=30)
175
+ fig = px.histogram(df, x=col, nbins=bins)
176
+ st.plotly_chart(fig)
177
+
178
+ elif viz_type == "Box Plot" and len(numeric_columns) >= 1:
179
+ y_col = st.selectbox("Select column for box plot", numeric_columns)
180
+ x_col = st.selectbox("Select grouping variable (optional)",
181
+ ["None"] + list(categorical_columns))
182
 
183
+ if x_col == "None":
184
+ fig = px.box(df, y=y_col)
185
+ else:
186
+ fig = px.box(df, x=x_col, y=y_col)
187
+ st.plotly_chart(fig)
188
+
189
+ elif viz_type == "Correlation Heatmap" and len(numeric_columns) >= 2:
190
+ corr_matrix = df[numeric_columns].corr()
191
+ fig = px.imshow(corr_matrix,
192
+ labels=dict(color="Correlation"),
193
+ x=corr_matrix.columns,
194
+ y=corr_matrix.columns)
195
+ st.plotly_chart(fig)
196
+
197
+ # Data Summary
198
+ st.header("📊 Data Summary")
199
+ if len(numeric_columns) > 0:
200
+ st.subheader("Numerical Columns Summary")
201
+ st.write(df[numeric_columns].describe())
202
+
203
+ if len(categorical_columns) > 0:
204
+ st.subheader("Categorical Columns Summary")
205
+ for col in categorical_columns:
206
+ st.write(f"\nValue counts for {col}:")
207
+ st.write(df[col].value_counts())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
  if __name__ == "__main__":
210
+ main()