# File-viewer header (not code): author sancho10 · commit message "Update main.py" · commit 4bf23fc (verified)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.ensemble import RandomForestClassifier, VotingClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.metrics import (
classification_report, confusion_matrix, accuracy_score,
ConfusionMatrixDisplay
)
import warnings
# Suppress every warning category for the rest of the process.
# NOTE(review): this blanket filter also hides library warnings (e.g. solver
# convergence or deprecation notices) — consider narrowing it to specific categories.
warnings.filterwarnings('ignore')
# =====================
# 1. Load Dataset
# =====================
df = pd.read_csv("ARTI_Main_Data.csv")
# Missing infection entries become the literal label "None", so that a missing
# value is encoded as its own category in the encoding step further down.
# NOTE(review): pandas parses the text "None" as NaN by default, which is
# presumably why this is needed — confirm against the CSV contents.
df = df.fillna({
    'Bacterial_Infection': "None",
    'Viral_Infection': "None",
})
# =====================
# 2. Set up features and multi-class target
# =====================
# Predictor columns, in the order they are passed to the scaler and the models.
features = [
    # Demographic columns
    'Age',
    'Sex',
    'Socioeconomic_Status',
    # Vitamin D columns
    'Vitamin_D_Level_ng/ml',
    'Vitamin_D_Status',
    'Vitamin_D_Supplemented',
    # Infection columns
    'Bacterial_Infection',
    'Viral_Infection',
    'Co_Infection',
    # IL6 / IL8 columns
    'IL6_pg/ml',
    'IL8_pg/ml',
]
# Column holding the class label the models are trained to predict.
target = 'ARTI_Severity'
# =====================
# 3. Encode features and target
# =====================
# Work on a copy so the raw feature columns in `df` stay untouched.
df_encoded = df[features].copy()
# Every object-dtype feature column is treated as categorical.
cat_cols = df_encoded.select_dtypes(include=['object']).columns
label_encoders = {}
for col in cat_cols:
    encoder = LabelEncoder().fit(df_encoded[col])
    df_encoded[col] = encoder.transform(df_encoded[col])
    # Keep the fitted encoder per column; the dict is persisted after training.
    label_encoders[col] = encoder
# Encode target (multi-class): class names -> integer labels, stored as a new column.
target_encoder = LabelEncoder()
df['ARTI_Severity_Label'] = target_encoder.fit_transform(df[target])
y = df['ARTI_Severity_Label']
# =====================
# 4. Train-test split
# =====================
# Split BEFORE scaling: fitting the scaler on the full dataset would let the
# test rows influence the mean/variance used to transform the training data
# (data leakage) and would bake test-set statistics into the saved scaler.
X_train_raw, X_test_raw, y_train, y_test = train_test_split(
    df_encoded, y, test_size=0.2, random_state=42, stratify=y
)
# =====================
# 5. Scale numerical features (statistics learned from the training rows only)
# =====================
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train_raw)
X_test = scaler.transform(X_test_raw)
# Fully scaled matrix, kept so any code that still refers to `X_scaled` keeps
# working; it is NOT used for fitting.
X_scaled = scaler.transform(df_encoded)
# =====================
# 6. Define Models with tuned parameters
# =====================
# class_weight='balanced' reweights classes inversely to their frequency
# (GradientBoostingClassifier is configured without class weighting).
log_reg = LogisticRegression(max_iter=1000, class_weight='balanced', C=1.0)
rf = RandomForestClassifier(n_estimators=200, max_depth=10, class_weight='balanced', random_state=42)
# probability=True is required for soft voting; its internal calibration
# shuffles the data, so seed it like the other models to keep runs reproducible.
svm = SVC(probability=True, kernel='rbf', C=1.5, class_weight='balanced', random_state=42)
gb = GradientBoostingClassifier(n_estimators=150, learning_rate=0.1, random_state=42)
# Voting classifier with soft voting: averages the predicted class
# probabilities of the four base models.
voting_model = VotingClassifier(
    estimators=[
        ('lr', log_reg),
        ('rf', rf),
        ('svm', svm),
        ('gb', gb),
    ],
    voting='soft',
)
# =====================
# 7. Train Model
# =====================
voting_model.fit(X_train, y_train)
# Save model and preprocessors: one file per fitted object, written in this order.
artifacts = (
    ("voting_model_multiclass.pkl", voting_model),
    ("scaler.pkl", scaler),
    ("feature_label_encoders.pkl", label_encoders),
    ("target_label_encoder.pkl", target_encoder),
)
for artifact_path, artifact in artifacts:
    joblib.dump(artifact, artifact_path)
# =====================
# 8. Evaluation
# =====================
# Predict on the held-out split and print the confusion matrix, the per-class
# report (labelled with the original class names) and the overall accuracy.
y_pred = voting_model.predict(X_test)
# The emoji below were mis-encoded (mojibake) in the previous revision; restored.
print("\n📊 Confusion Matrix:\n", confusion_matrix(y_test, y_pred))
print("\n📑 Classification Report:\n", classification_report(y_test, y_pred, target_names=target_encoder.classes_))
print("\n✅ Accuracy Score:", accuracy_score(y_test, y_pred))
# =====================
# 9. Visualizations
# =====================
# 1. Confusion Matrix
cm_counts = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(
    confusion_matrix=cm_counts,
    display_labels=target_encoder.classes_,
)
disp.plot(cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
# Write the image before show(), which hands the figure over to the GUI backend.
plt.savefig("confusion_matrix_multiclass.png")
plt.show()
# 2. Feature Importance (Random Forest)
# VotingClassifier fits clones of its estimators, so the module-level `rf` is
# not fitted by voting_model.fit(). Reuse the forest already trained inside
# the ensemble instead of training a second 200-tree forest just for this plot;
# the chart then describes the model that actually makes the predictions.
fitted_rf = voting_model.named_estimators_['rf']
plt.figure(figsize=(8, 5))
importances = fitted_rf.feature_importances_
# Sort features from most to least important.
indices = np.argsort(importances)[::-1]
feature_names = df_encoded.columns
sns.barplot(x=importances[indices], y=np.array(feature_names)[indices], palette='viridis')
plt.title("Feature Importance (Random Forest)")
plt.xlabel("Importance Score")
plt.ylabel("Features")
plt.tight_layout()
plt.savefig("feature_importance_rf.png")
plt.show()
# 3. Class Distribution
# Bar chart of how many rows fall into each original (un-encoded) severity class.
dist_fig, dist_ax = plt.subplots(figsize=(6, 4))
sns.countplot(x=df[target], palette='pastel', ax=dist_ax)
dist_ax.set_title("Distribution of ARTI Severity Classes")
dist_ax.set_xlabel("ARTI Severity")
dist_ax.set_ylabel("Count")
dist_fig.tight_layout()
dist_fig.savefig("class_distribution.png")
plt.show()
# 4. Actual vs Predicted Comparison
# Annotated heatmap of the test-set confusion matrix, with both axes labelled
# by the original class names.
class_names = target_encoder.classes_
heatmap_counts = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 5))
sns.heatmap(
    heatmap_counts,
    annot=True,
    fmt="d",
    cmap="YlGnBu",
    xticklabels=class_names,
    yticklabels=class_names,
)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Actual vs Predicted Heatmap")
plt.savefig("actual_vs_predicted_heatmap.png")
plt.show()