sancho10 committed (verified)
Commit 4bf23fc · Parent(s): 6b89802

Update main.py

Files changed (1): main.py (+148, -146)
main.py CHANGED
@@ -1,146 +1,148 @@
- import pandas as pd
- import numpy as np
- import matplotlib.pyplot as plt
- import seaborn as sns
- import joblib
- from sklearn.model_selection import train_test_split
- from sklearn.preprocessing import LabelEncoder, StandardScaler
- from sklearn.ensemble import RandomForestClassifier, VotingClassifier
- from sklearn.linear_model import LogisticRegression
- from sklearn.svm import SVC
- from sklearn.metrics import (
-     classification_report, confusion_matrix, accuracy_score,
-     ConfusionMatrixDisplay
- )
- import warnings
- warnings.filterwarnings('ignore')
-
- # =====================
- # 1. Load Dataset
- # =====================
- df = pd.read_csv("ARTI_Main_Data.csv")
-
- # Handle missing values
- df['Bacterial_Infection'] = df['Bacterial_Infection'].fillna("None")
- df['Viral_Infection'] = df['Viral_Infection'].fillna("None")
-
- # =====================
- # 2. Set up features and multi-class target
- # =====================
- features = [
-     'Age', 'Sex', 'Socioeconomic_Status', 'Vitamin_D_Level_ng/ml',
-     'Vitamin_D_Status', 'Vitamin_D_Supplemented', 'Bacterial_Infection',
-     'Viral_Infection', 'Co_Infection', 'IL6_pg/ml', 'IL8_pg/ml'
- ]
- target = 'ARTI_Severity'
-
- # =====================
- # 3. Encode features and target
- # =====================
- df_encoded = df[features].copy()
- cat_cols = df_encoded.select_dtypes(include=['object']).columns
- label_encoders = {}
-
- for col in cat_cols:
-     le = LabelEncoder()
-     df_encoded[col] = le.fit_transform(df_encoded[col])
-     label_encoders[col] = le
-
- # Encode target (multi-class)
- target_encoder = LabelEncoder()
- df['ARTI_Severity_Label'] = target_encoder.fit_transform(df[target])
- y = df['ARTI_Severity_Label']
-
- # =====================
- # 4. Scale numerical features
- # =====================
- scaler = StandardScaler()
- X_scaled = scaler.fit_transform(df_encoded)
-
- # =====================
- # 5. Train-test split
- # =====================
- X_train, X_test, y_train, y_test = train_test_split(
-     X_scaled, y, test_size=0.2, random_state=42, stratify=y
- )
-
- # =====================
- # 6. Define Models
- # =====================
- log_reg = LogisticRegression(max_iter=500)
- rf = RandomForestClassifier(n_estimators=100, random_state=42)
- svm = SVC(probability=True)
-
- # Voting classifier
- voting_model = VotingClassifier(estimators=[
-     ('lr', log_reg),
-     ('rf', rf),
-     ('svm', svm)
- ], voting='hard')
-
- # =====================
- # 7. Train Model
- # =====================
- voting_model.fit(X_train, y_train)
-
- # Save model and preprocessors
- joblib.dump(voting_model, "voting_model_multiclass.pkl")
- joblib.dump(scaler, "scaler.pkl")
- joblib.dump(label_encoders, "feature_label_encoders.pkl")
- joblib.dump(target_encoder, "target_label_encoder.pkl")
-
- # =====================
- # 8. Evaluation
- # =====================
- y_pred = voting_model.predict(X_test)
-
- print("\n📊 Confusion Matrix:\n", confusion_matrix(y_test, y_pred))
- print("\n📑 Classification Report:\n", classification_report(y_test, y_pred, target_names=target_encoder.classes_))
- print("\n Accuracy Score:", accuracy_score(y_test, y_pred))
-
- # =====================
- # 9. Visualizations
- # =====================
-
- # 1. Confusion Matrix
- disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix(y_test, y_pred),
-                               display_labels=target_encoder.classes_)
- disp.plot(cmap=plt.cm.Blues)
- plt.title("Confusion Matrix")
- plt.savefig("confusion_matrix_multiclass.png")
- plt.show()
-
- # 2. Feature Importance (fit rf separately for this)
- rf.fit(X_train, y_train)
- plt.figure(figsize=(8, 5))
- importances = rf.feature_importances_
- indices = np.argsort(importances)[::-1]
- feature_names = df_encoded.columns
-
- sns.barplot(x=importances[indices], y=np.array(feature_names)[indices], palette='viridis')
- plt.title("Feature Importance (Random Forest)")
- plt.xlabel("Importance Score")
- plt.ylabel("Features")
- plt.tight_layout()
- plt.savefig("feature_importance_rf.png")
- plt.show()
-
- # 3. Class Distribution
- plt.figure(figsize=(6, 4))
- sns.countplot(x=df[target], palette='pastel')
- plt.title("Distribution of ARTI Severity Classes")
- plt.xlabel("ARTI Severity")
- plt.ylabel("Count")
- plt.tight_layout()
- plt.savefig("class_distribution.png")
- plt.show()
-
- # 4. Actual vs Predicted Comparison
- plt.figure(figsize=(8, 5))
- sns.heatmap(confusion_matrix(y_test, y_pred), annot=True, fmt="d", cmap="YlGnBu",
-             xticklabels=target_encoder.classes_, yticklabels=target_encoder.classes_)
- plt.xlabel("Predicted")
- plt.ylabel("Actual")
- plt.title("Actual vs Predicted Heatmap")
- plt.savefig("actual_vs_predicted_heatmap.png")
- plt.show()
 
 
 
+ import pandas as pd
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import joblib
+ from sklearn.model_selection import train_test_split
+ from sklearn.preprocessing import LabelEncoder, StandardScaler
+ from sklearn.ensemble import RandomForestClassifier, VotingClassifier, GradientBoostingClassifier
+ from sklearn.linear_model import LogisticRegression
+ from sklearn.svm import SVC
+ from sklearn.metrics import (
+     classification_report, confusion_matrix, accuracy_score,
+     ConfusionMatrixDisplay
+ )
+ import warnings
+ warnings.filterwarnings('ignore')
+
+ # =====================
+ # 1. Load Dataset
+ # =====================
+ df = pd.read_csv("ARTI_Main_Data.csv")
+
+ # Handle missing values
+ df['Bacterial_Infection'] = df['Bacterial_Infection'].fillna("None")
+ df['Viral_Infection'] = df['Viral_Infection'].fillna("None")
+
+ # =====================
+ # 2. Set up features and multi-class target
+ # =====================
+ features = [
+     'Age', 'Sex', 'Socioeconomic_Status', 'Vitamin_D_Level_ng/ml',
+     'Vitamin_D_Status', 'Vitamin_D_Supplemented', 'Bacterial_Infection',
+     'Viral_Infection', 'Co_Infection', 'IL6_pg/ml', 'IL8_pg/ml'
+ ]
+ target = 'ARTI_Severity'
+
+ # =====================
+ # 3. Encode features and target
+ # =====================
+ df_encoded = df[features].copy()
+ cat_cols = df_encoded.select_dtypes(include=['object']).columns
+ label_encoders = {}
+
+ for col in cat_cols:
+     le = LabelEncoder()
+     df_encoded[col] = le.fit_transform(df_encoded[col])
+     label_encoders[col] = le
+
+ # Encode target (multi-class)
+ target_encoder = LabelEncoder()
+ df['ARTI_Severity_Label'] = target_encoder.fit_transform(df[target])
+ y = df['ARTI_Severity_Label']
+
+ # =====================
+ # 4. Scale numerical features
+ # =====================
+ scaler = StandardScaler()
+ X_scaled = scaler.fit_transform(df_encoded)
+
+ # =====================
+ # 5. Train-test split
+ # =====================
+ X_train, X_test, y_train, y_test = train_test_split(
+     X_scaled, y, test_size=0.2, random_state=42, stratify=y
+ )
+
+ # =====================
+ # 6. Define Models with tuned parameters
+ # =====================
+ log_reg = LogisticRegression(max_iter=1000, class_weight='balanced', C=1.0)
+ rf = RandomForestClassifier(n_estimators=200, max_depth=10, class_weight='balanced', random_state=42)
+ svm = SVC(probability=True, kernel='rbf', C=1.5, class_weight='balanced')
+ gb = GradientBoostingClassifier(n_estimators=150, learning_rate=0.1, random_state=42)
+
+ # Voting classifier with soft voting
+ voting_model = VotingClassifier(estimators=[
+     ('lr', log_reg),
+     ('rf', rf),
+     ('svm', svm),
+     ('gb', gb)
+ ], voting='soft')
+
+ # =====================
+ # 7. Train Model
+ # =====================
+ voting_model.fit(X_train, y_train)
+
+ # Save model and preprocessors
+ joblib.dump(voting_model, "voting_model_multiclass.pkl")
+ joblib.dump(scaler, "scaler.pkl")
+ joblib.dump(label_encoders, "feature_label_encoders.pkl")
+ joblib.dump(target_encoder, "target_label_encoder.pkl")
+
+ # =====================
+ # 8. Evaluation
+ # =====================
+ y_pred = voting_model.predict(X_test)
+
+ print("\n📊 Confusion Matrix:\n", confusion_matrix(y_test, y_pred))
+ print("\n📑 Classification Report:\n", classification_report(y_test, y_pred, target_names=target_encoder.classes_))
+ print("\n✅ Accuracy Score:", accuracy_score(y_test, y_pred))
+
+ # =====================
+ # 9. Visualizations
+ # =====================
+
+ # 1. Confusion Matrix
+ disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix(y_test, y_pred),
+                               display_labels=target_encoder.classes_)
+ disp.plot(cmap=plt.cm.Blues)
+ plt.title("Confusion Matrix")
+ plt.savefig("confusion_matrix_multiclass.png")
+ plt.show()
+
+ # 2. Feature Importance (Random Forest)
+ rf.fit(X_train, y_train)
+ plt.figure(figsize=(8, 5))
+ importances = rf.feature_importances_
+ indices = np.argsort(importances)[::-1]
+ feature_names = df_encoded.columns
+
+ sns.barplot(x=importances[indices], y=np.array(feature_names)[indices], palette='viridis')
+ plt.title("Feature Importance (Random Forest)")
+ plt.xlabel("Importance Score")
+ plt.ylabel("Features")
+ plt.tight_layout()
+ plt.savefig("feature_importance_rf.png")
+ plt.show()
+
+ # 3. Class Distribution
+ plt.figure(figsize=(6, 4))
+ sns.countplot(x=df[target], palette='pastel')
+ plt.title("Distribution of ARTI Severity Classes")
+ plt.xlabel("ARTI Severity")
+ plt.ylabel("Count")
+ plt.tight_layout()
+ plt.savefig("class_distribution.png")
+ plt.show()
+
+ # 4. Actual vs Predicted Comparison
+ plt.figure(figsize=(8, 5))
+ sns.heatmap(confusion_matrix(y_test, y_pred), annot=True, fmt="d", cmap="YlGnBu",
+             xticklabels=target_encoder.classes_, yticklabels=target_encoder.classes_)
+ plt.xlabel("Predicted")
+ plt.ylabel("Actual")
+ plt.title("Actual vs Predicted Heatmap")
+ plt.savefig("actual_vs_predicted_heatmap.png")
+ plt.show()
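
Not part of the commit itself, but for context on the artifacts it writes: main.py persists the ensemble and its preprocessors with joblib (voting_model_multiclass.pkl, scaler.pkl, feature_label_encoders.pkl, target_label_encoder.pkl). Below is a minimal, hypothetical inference sketch showing how they could be reloaded and applied to a new record. The sample values ('Male', 'RSV', and so on) are placeholders and must match categories present in ARTI_Main_Data.csv, since LabelEncoder.transform rejects unseen labels.

# Inference sketch (assumption: the four joblib artifacts saved by main.py
# exist in the working directory; the sample record below is hypothetical).
import joblib
import pandas as pd

model = joblib.load("voting_model_multiclass.pkl")
scaler = joblib.load("scaler.pkl")
label_encoders = joblib.load("feature_label_encoders.pkl")
target_encoder = joblib.load("target_label_encoder.pkl")

# Placeholder patient record; column names and order mirror the training
# `features` list, but every value here is illustrative only.
sample = pd.DataFrame([{
    'Age': 4, 'Sex': 'Male', 'Socioeconomic_Status': 'Middle',
    'Vitamin_D_Level_ng/ml': 18.5, 'Vitamin_D_Status': 'Deficient',
    'Vitamin_D_Supplemented': 'No', 'Bacterial_Infection': 'None',
    'Viral_Infection': 'RSV', 'Co_Infection': 'No',
    'IL6_pg/ml': 12.3, 'IL8_pg/ml': 30.1
}])

# Re-apply the label encoding learned at training time
for col, le in label_encoders.items():
    sample[col] = le.transform(sample[col])

# Scale, predict, and decode the class index back to its severity label
X_new = scaler.transform(sample)
pred = model.predict(X_new)
print("Predicted ARTI severity:", target_encoder.inverse_transform(pred)[0])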