Osnly commited on
Commit
29d5a95
·
verified ·
1 Parent(s): 802fc25

Update src/visualize.py

Browse files
Files changed (1) hide show
  1. src/visualize.py +40 -40
src/visualize.py CHANGED
@@ -1,40 +1,40 @@
1
- # visualize.py
2
- import pandas as pd
3
- import seaborn as sns
4
- import matplotlib.pyplot as plt
5
- import os
6
-
7
- def plot_column_distributions(df: pd.DataFrame, output_dir="charts"):
8
- os.makedirs(output_dir, exist_ok=True)
9
-
10
- for col in df.columns:
11
- plt.figure(figsize=(6, 4))
12
-
13
- if pd.api.types.is_numeric_dtype(df[col]):
14
- sns.histplot(df[col].dropna(), kde=True)
15
- plt.title(f"Distribution of {col}")
16
- elif pd.api.types.is_categorical_dtype(df[col]) or df[col].dtype == object:
17
- top_vals = df[col].value_counts().nlargest(10)
18
- sns.barplot(x=top_vals.values, y=top_vals.index)
19
- plt.title(f"Top categories in {col}")
20
- else:
21
- continue
22
-
23
- plt.tight_layout()
24
- plt.savefig(f"{output_dir}/{col}.png")
25
- plt.close()
26
-
27
- def plot_relationships(df, target_col='income', output_dir="charts"):
28
- os.makedirs(output_dir, exist_ok=True)
29
-
30
- for col in df.columns:
31
- if col == target_col:
32
- continue
33
-
34
- if pd.api.types.is_numeric_dtype(df[col]) and target_col in df.columns:
35
- plt.figure(figsize=(6, 4))
36
- sns.boxplot(x=target_col, y=col, data=df)
37
- plt.title(f"{col} by {target_col}")
38
- plt.tight_layout()
39
- plt.savefig(f"{output_dir}/{col}_by_{target_col}.png")
40
- plt.close()
 
1
+ # visualize.py
2
+ import pandas as pd
3
+ import seaborn as sns
4
+ import matplotlib.pyplot as plt
5
+ import os
6
+
7
+ def plot_column_distributions(df: pd.DataFrame, charts_dir="/tmp/charts"):
8
+ os.makedirs(output_dir, exist_ok=True)
9
+
10
+ for col in df.columns:
11
+ plt.figure(figsize=(6, 4))
12
+
13
+ if pd.api.types.is_numeric_dtype(df[col]):
14
+ sns.histplot(df[col].dropna(), kde=True)
15
+ plt.title(f"Distribution of {col}")
16
+ elif pd.api.types.is_categorical_dtype(df[col]) or df[col].dtype == object:
17
+ top_vals = df[col].value_counts().nlargest(10)
18
+ sns.barplot(x=top_vals.values, y=top_vals.index)
19
+ plt.title(f"Top categories in {col}")
20
+ else:
21
+ continue
22
+
23
+ plt.tight_layout()
24
+ plt.savefig(f"{output_dir}/{col}.png")
25
+ plt.close()
26
+
27
+ def plot_relationships(df, target_col='income', charts_dir="/tmp/charts"):
28
+ os.makedirs(output_dir, exist_ok=True)
29
+
30
+ for col in df.columns:
31
+ if col == target_col:
32
+ continue
33
+
34
+ if pd.api.types.is_numeric_dtype(df[col]) and target_col in df.columns:
35
+ plt.figure(figsize=(6, 4))
36
+ sns.boxplot(x=target_col, y=col, data=df)
37
+ plt.title(f"{col} by {target_col}")
38
+ plt.tight_layout()
39
+ plt.savefig(f"{output_dir}/{col}_by_{target_col}.png")
40
+ plt.close()