Update app.py
Browse files
app.py
CHANGED
@@ -2,249 +2,261 @@ import os
|
|
2 |
import requests
|
3 |
import gradio as gr
|
4 |
import pandas as pd
|
|
|
5 |
import matplotlib.pyplot as plt
|
6 |
import seaborn as sns
|
7 |
-
import
|
8 |
-
|
|
|
9 |
from sklearn.model_selection import train_test_split
|
10 |
-
from sklearn.
|
11 |
-
from sklearn.
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
def extended_analysis(df):
|
49 |
-
"""
|
50 |
-
Does correlation heatmap, bar plot for 'Career', and logistic regression
|
51 |
-
if 'Career' has multiple categories. Returns (list_of_image_paths, info_string).
|
52 |
-
"""
|
53 |
-
output_paths = []
|
54 |
-
numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
|
55 |
-
|
56 |
-
# 1) Correlation Heatmap
|
57 |
-
if len(numeric_cols) > 1:
|
58 |
-
corr = df[numeric_cols].corr()
|
59 |
-
plt.figure(figsize=(8, 6))
|
60 |
-
sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f")
|
61 |
-
plt.title("Correlation Heatmap")
|
62 |
-
heatmap_path = "heatmap.png"
|
63 |
-
plt.savefig(heatmap_path)
|
64 |
-
plt.close()
|
65 |
-
output_paths.append(heatmap_path)
|
66 |
-
|
67 |
-
# 2) Bar Plot for 'Career'
|
68 |
-
if "Career" in df.columns:
|
69 |
-
plt.figure(figsize=(8, 5))
|
70 |
-
career_counts = df["Career"].value_counts()
|
71 |
-
sns.barplot(x=career_counts.index, y=career_counts.values)
|
72 |
-
plt.title("Distribution of Careers")
|
73 |
-
plt.xlabel("Career")
|
74 |
-
plt.ylabel("Count")
|
75 |
-
plt.xticks(rotation=45, ha="right")
|
76 |
-
barplot_path = "career_distribution.png"
|
77 |
-
plt.savefig(barplot_path)
|
78 |
-
plt.close()
|
79 |
-
output_paths.append(barplot_path)
|
80 |
-
|
81 |
-
# 3) Simple Logistic Regression
|
82 |
-
if "Career" in df.columns and len(numeric_cols) > 0:
|
83 |
-
le = LabelEncoder()
|
84 |
-
df["Career_encoded"] = le.fit_transform(df["Career"])
|
85 |
-
X = df[numeric_cols].fillna(0)
|
86 |
-
y = df["Career_encoded"]
|
87 |
-
if len(np.unique(y)) > 1:
|
88 |
-
X_train, X_test, y_train, y_test = train_test_split(
|
89 |
-
X, y, test_size=0.2, random_state=42
|
90 |
)
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
else:
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
if
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
def on_user_message(message, df, chat_history, api_key):
    """
    Process one submitted chat message and return the resulting chat history.

    A message that is empty or whitespace-only is ignored and the history is
    returned untouched; otherwise the work is delegated to ``handle_chat``.
    """
    has_content = bool(message.strip())
    if has_content:
        # handle_chat produces the updated list of chat messages.
        return handle_chat(message, df, chat_history, api_key)
    return chat_history  # ignore empty
|
205 |
-
|
206 |
-
user_message.submit(
|
207 |
-
fn=on_user_message,
|
208 |
-
inputs=[user_message, df_state, chat_state, api_key_box],
|
209 |
-
outputs=chat_state
|
210 |
-
).then(
|
211 |
-
# After updating chat_state, reflect it in the chatbot
|
212 |
-
fn=lambda messages: messages,
|
213 |
-
inputs=chat_state,
|
214 |
-
outputs=chatbot
|
215 |
-
).then(
|
216 |
-
fn=lambda: "",
|
217 |
-
outputs=user_message
|
218 |
-
)
|
219 |
-
|
220 |
-
# Button to send message
|
221 |
-
send_btn = gr.Button("Send")
|
222 |
-
send_btn.click(
|
223 |
-
fn=on_user_message,
|
224 |
-
inputs=[user_message, df_state, chat_state, api_key_box],
|
225 |
-
outputs=chat_state
|
226 |
-
).then(
|
227 |
-
fn=lambda messages: messages,
|
228 |
-
inputs=chat_state,
|
229 |
-
outputs=chatbot
|
230 |
-
).then(
|
231 |
-
fn=lambda: "",
|
232 |
-
outputs=user_message
|
233 |
-
)
|
234 |
-
|
235 |
-
# Clear chat button
|
236 |
-
clear_btn = gr.Button("Clear Chat")
|
237 |
-
def clear_chat():
    """Return two fresh empty lists: one for the chat state, one for the chatbot display."""
    empty_state = []
    empty_display = []
    return empty_state, empty_display
|
239 |
-
clear_btn.click(
|
240 |
-
fn=clear_chat,
|
241 |
-
inputs=[],
|
242 |
-
outputs=[chat_state, chatbot]
|
243 |
-
)
|
244 |
-
|
245 |
-
return demo
|
246 |
-
|
247 |
-
demo = create_demo()
|
248 |
|
249 |
if __name__ == "__main__":
|
250 |
-
|
|
|
|
|
|
2 |
import requests
|
3 |
import gradio as gr
|
4 |
import pandas as pd
|
5 |
+
import numpy as np
|
6 |
import matplotlib.pyplot as plt
|
7 |
import seaborn as sns
|
8 |
+
from typing import Dict, List, Tuple, Optional
|
9 |
+
from dataclasses import dataclass
|
10 |
+
from sklearn.preprocessing import StandardScaler, LabelEncoder
|
11 |
from sklearn.model_selection import train_test_split
|
12 |
+
from sklearn.metrics import mean_squared_error, r2_score, accuracy_score
|
13 |
+
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
|
14 |
+
from sklearn.impute import SimpleImputer
|
15 |
+
import statsmodels.api as sm
|
16 |
+
import plotly.express as px
|
17 |
+
import plotly.graph_objects as go
|
18 |
+
from scipy import stats
|
19 |
+
|
20 |
+
@dataclass
class AnalysisConfig:
    """Tunable defaults shared by the analysis routines."""

    # Upper bound on analysis iterations.
    max_iterations: int = 5
    # Minimum number of samples expected before running an analysis.
    min_samples_for_analysis: int = 30
    # Correlation magnitude threshold.
    correlation_threshold: float = 0.7
    # Maximum number of categories to include in a visualization.
    max_categories_for_viz: int = 10
    # p-value cut-off used when interpreting statistical tests.
    significance_level: float = 0.05
|
28 |
+
|
29 |
+
class DataAnalyzer:
    """Data analysis agent: profiles a DataFrame, runs tests, plots and models.

    Attributes:
        api_key: Bearer token sent with every language-model request.
        config: ``AnalysisConfig`` holding thresholds such as the
            significance level.
        current_iteration: Counter embedded in saved plot file names.
        analysis_results: List reserved for accumulated results.
    """

    def __init__(self, api_key: str):
        self.api_key = api_key
        self.config = AnalysisConfig()
        self.current_iteration = 0
        self.analysis_results = []

    def call_gpt4o_mini(self, prompt: str) -> str:
        """Send ``prompt`` to the chat endpoint and return the generated text.

        Network failures, HTTP error statuses, invalid JSON and unexpected
        response shapes are reported as an ``"API Error: ..."`` string instead
        of being raised, so callers can embed the result directly in output.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        payload = {"prompt": prompt, "max_tokens": 500, "temperature": 0.7}
        try:
            response = requests.post(
                "https://api.gpt4o-mini.example.com/v1/chat",  # Replace with actual endpoint
                json=payload,
                headers=headers,
                timeout=15
            )
            response.raise_for_status()
            return response.json()["choices"][0]["text"]
        except (requests.RequestException, ValueError, KeyError, IndexError, TypeError) as e:
            # ValueError covers a non-JSON body; Key/Index/TypeError cover a
            # JSON body that lacks choices[0]["text"].
            return f"API Error: {str(e)}"

    def analyze_data_types(self, df: pd.DataFrame) -> Dict:
        """Group columns by dtype and collect per-column missing/unique counts.

        Returns:
            A dict with keys ``numeric_cols``, ``categorical_cols`` and
            ``temporal_cols`` (lists of column names) plus ``missing_values``
            and ``unique_counts`` (column name -> count).
        """
        analysis = {
            "numeric_cols": df.select_dtypes(include=['int64', 'float64']).columns.tolist(),
            "categorical_cols": df.select_dtypes(include=['object', 'category']).columns.tolist(),
            "temporal_cols": df.select_dtypes(include=['datetime64']).columns.tolist(),
            "missing_values": df.isnull().sum().to_dict(),
            "unique_counts": df.nunique().to_dict()
        }
        return analysis

    def create_visualization(self, df: pd.DataFrame, viz_type: str, columns: List[str]) -> str:
        """Render one plot to a PNG file and return the file path.

        Args:
            df: Source data.
            viz_type: ``"correlation"``, ``"distribution"`` or ``"boxplot"``.
            columns: Columns of ``df`` to plot.

        Returns:
            Path of the saved image. The name contains both ``viz_type`` and
            ``self.current_iteration`` so different plot types produced in the
            same iteration do not overwrite each other.

        Raises:
            ValueError: If ``columns`` is empty or ``viz_type`` is unsupported.
        """
        if len(columns) == 0:
            raise ValueError("No columns supplied for visualization")
        if viz_type not in ("correlation", "distribution", "boxplot"):
            raise ValueError(f"Unsupported viz_type: {viz_type!r}")

        if viz_type == "distribution":
            # One panel per column: overlaying histograms of differently
            # scaled variables on a single axes is unreadable.
            fig, axes = plt.subplots(
                len(columns), 1, figsize=(10, 4 * len(columns)), squeeze=False
            )
        else:
            fig, single_ax = plt.subplots(figsize=(10, 6))

        output_path = f"viz_{viz_type}_{self.current_iteration}.png"
        try:
            if viz_type == "correlation":
                sns.heatmap(df[columns].corr(), annot=True, cmap='coolwarm', ax=single_ax)
                single_ax.set_title("Correlation Matrix")
            elif viz_type == "distribution":
                for ax, col in zip(axes.ravel(), columns):
                    sns.histplot(data=df, x=col, kde=True, ax=ax)
                    ax.set_title(f"Distribution of {col}")
            else:
                sns.boxplot(data=df[columns], ax=single_ax)
                single_ax.set_title("Box Plot of Numeric Variables")
            fig.tight_layout()
            fig.savefig(output_path)
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close(fig)
        return output_path

    def perform_statistical_tests(self, df: pd.DataFrame, data_types: Dict) -> Dict:
        """Run normality and chi-square independence tests.

        Args:
            df: Source data.
            data_types: Mapping with ``numeric_cols`` and ``categorical_cols``
                lists, as produced by ``analyze_data_types``.

        Returns:
            Dict keyed ``normality_<col>`` and ``chi2_<col1>_<col2>``. Each
            value holds ``statistic``, ``p_value`` and a boolean verdict
            (``is_normal`` or ``is_significant``) judged against
            ``self.config.significance_level``.
        """
        from itertools import combinations

        alpha = self.config.significance_level
        results = {}

        # Normality tests for numeric columns
        for col in data_types["numeric_cols"]:
            values = df[col].dropna()
            # normaltest relies on skewtest, which needs at least 8 observations.
            if len(values) >= 8:
                stat, p_value = stats.normaltest(values)
                results[f"normality_{col}"] = {
                    "statistic": float(stat),
                    "p_value": float(p_value),
                    "is_normal": bool(p_value > alpha)
                }

        # Chi-square tests for categorical columns; sorting keeps each pair's
        # key in alphabetical order.
        for col1, col2 in combinations(sorted(data_types["categorical_cols"]), 2):
            if col1 == col2:
                continue
            contingency = pd.crosstab(df[col1], df[col2])
            if contingency.size == 0:
                # No row has both values present; chi2_contingency would raise.
                continue
            chi2, p_value, _, _ = stats.chi2_contingency(contingency)
            results[f"chi2_{col1}_{col2}"] = {
                "statistic": float(chi2),
                "p_value": float(p_value),
                "is_significant": bool(p_value < alpha)
            }

        return results

    def train_predictive_model(self, df: pd.DataFrame, target_col: str) -> Tuple[float, str]:
        """Fit a random forest on ``df`` to predict ``target_col`` and score it.

        Numeric features are median-imputed and standardized; object-dtype
        features are imputed with ``'missing'`` and one-hot encoded. The task
        is treated as classification when the target is non-numeric or has at
        most five distinct values, otherwise as regression.

        Returns:
            ``(score, metric)`` on a 20% hold-out split, where ``metric`` is
            ``'accuracy'`` or ``'r2'``.

        Raises:
            KeyError: If ``target_col`` is not a column of ``df``.
        """
        # Imported here because these names are not among the module-level imports.
        from sklearn.compose import ColumnTransformer
        from sklearn.pipeline import Pipeline
        from sklearn.preprocessing import OneHotEncoder

        # Rows without a target value cannot be used for supervised training.
        labelled = df.dropna(subset=[target_col])
        X = labelled.drop(columns=[target_col])
        y = labelled[target_col]

        # Preprocessing
        numeric_transformer = Pipeline([
            ('imputer', SimpleImputer(strategy='median')),
            ('scaler', StandardScaler())
        ])

        categorical_transformer = Pipeline([
            ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
            ('onehot', OneHotEncoder(handle_unknown='ignore'))
        ])

        preprocessor = ColumnTransformer(
            transformers=[
                ('num', numeric_transformer, X.select_dtypes(include=['int64', 'float64']).columns),
                ('cat', categorical_transformer, X.select_dtypes(include=['object']).columns)
            ])

        # A string-valued target can never be regressed, however many classes it has.
        is_classification = (not pd.api.types.is_numeric_dtype(y)) or y.nunique() <= 5
        if is_classification:
            model = RandomForestClassifier(n_estimators=100, random_state=42)
            metric = 'accuracy'
        else:
            model = RandomForestRegressor(n_estimators=100, random_state=42)
            metric = 'r2'

        pipeline = Pipeline([
            ('preprocessor', preprocessor),
            ('model', model)
        ])

        # Train and evaluate
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        pipeline.fit(X_train, y_train)
        y_pred = pipeline.predict(X_test)

        if metric == 'accuracy':
            score = accuracy_score(y_test, y_pred)
        else:
            score = r2_score(y_test, y_pred)

        return score, metric
|
160 |
+
|
161 |
+
class GradioInterface:
    """Gradio interface for the data analysis agent.

    Attributes:
        analyzer: ``DataAnalyzer`` created by the most recent analysis run,
            or ``None`` before the first run.
        df: DataFrame loaded by the most recent analysis run, or ``None``.
    """

    def __init__(self):
        self.analyzer = None
        self.df = None

    def create_interface(self):
        """Build the Blocks UI, wire up its callbacks and return it."""
        with gr.Blocks() as demo:
            gr.Markdown("# Intelligent Data Analysis Agent")

            with gr.Row():
                api_key = gr.Textbox(label="GPT-4o-mini API Key", type="password")
                file_input = gr.File(label="Upload CSV file")

            with gr.Row():
                analysis_notes = gr.Textbox(label="Analysis Notes (Optional)",
                                            placeholder="Any specific analysis preferences...")

            with gr.Row():
                analyze_btn = gr.Button("Analyze Data")
                clear_btn = gr.Button("Clear")

            output_text = gr.Markdown()
            output_gallery = gr.Gallery()

            def analyze(api_key, file, notes):
                """Run the full analysis; return (markdown summary, image paths)."""
                if not api_key or not file:
                    return "Please provide both API key and data file.", None

                try:
                    # The upload may arrive as an object exposing ``.name`` or
                    # as a plain path string; accept both.
                    csv_path = getattr(file, "name", file)
                    df = pd.read_csv(csv_path)
                    analyzer = DataAnalyzer(api_key)
                    self.df = df
                    self.analyzer = analyzer

                    # Get AI suggestions for analysis
                    prompt = f"Data columns: {list(df.columns)}\nUser notes: {notes}\nSuggest appropriate analyses and visualizations."
                    ai_suggestions = analyzer.call_gpt4o_mini(prompt)

                    # Perform analysis
                    data_types = analyzer.analyze_data_types(df)
                    stats_results = analyzer.perform_statistical_tests(df, data_types)

                    # Create visualizations
                    viz_paths = []
                    numeric_cols = data_types["numeric_cols"]
                    if numeric_cols:
                        for viz_type in ["correlation", "distribution", "boxplot"]:
                            viz_paths.append(
                                analyzer.create_visualization(df, viz_type, numeric_cols)
                            )
                            # Plot file names embed this counter; advancing it
                            # keeps each image in its own file.
                            analyzer.current_iteration += 1

                    # Generate summary. Built from unindented lines so Markdown
                    # renders headings rather than an indented code block.
                    stats_text = self._format_stats_results(stats_results)
                    summary = "\n".join([
                        "## Data Analysis Results",
                        "",
                        "### AI Suggestions",
                        str(ai_suggestions),
                        "",
                        "### Basic Statistics",
                        f"- Rows: {len(df)}",
                        f"- Columns: {len(df.columns)}",
                        f"- Missing Values: {sum(data_types['missing_values'].values())}",
                        "",
                        "### Statistical Tests",
                        stats_text or "_No applicable statistical tests._",
                    ])

                    return summary, viz_paths

                except Exception as e:
                    # UI boundary: report the failure to the user instead of crashing.
                    return f"Error during analysis: {str(e)}", None

            analyze_btn.click(
                analyze,
                inputs=[api_key, file_input, analysis_notes],
                outputs=[output_text, output_gallery]
            )

            clear_btn.click(
                lambda: (None, None),
                outputs=[output_text, output_gallery]
            )

        return demo

    @staticmethod
    def _format_stats_results(results: Dict) -> str:
        """Format statistical results as Markdown bullet lines.

        The kind of each entry is determined from the verdict key it carries
        (``is_normal`` or ``is_significant``) rather than from its name, so a
        column name containing "normality" or "chi2" cannot select the wrong
        branch. Entries with neither key are skipped.
        """
        formatted = []
        for test_name, result in results.items():
            if "is_normal" in result:
                verdict = 'Normal' if result['is_normal'] else 'Non-normal'
            elif "is_significant" in result:
                verdict = 'Significant' if result['is_significant'] else 'Not significant'
            else:
                continue
            formatted.append(f"- {test_name}: {verdict} (p={result['p_value']:.4f})")
        return "\n".join(formatted)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
258 |
|
259 |
if __name__ == "__main__":
    # Script entry point: build the Blocks app and serve it with a share link.
    interface = GradioInterface()
    demo = interface.create_interface()
    launch_options = {"share": True}
    demo.launch(**launch_options)
|