# ARTI severity classification: train a voting ensemble and save model artifacts.
# Uploaded by sancho10 ("Upload 20 files", commit f322e76, verified; 4.63 kB).
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.metrics import (
classification_report, confusion_matrix, accuracy_score,
ConfusionMatrixDisplay
)
import warnings
warnings.simplefilter('ignore')
# =====================
# 1. Load Dataset
# =====================
df = pd.read_csv("ARTI_Main_Data.csv")
# Missing infection entries become an explicit "None" category so they survive
# label encoding as a regular value instead of NaN.
for _infection_col in ('Bacterial_Infection', 'Viral_Infection'):
    df[_infection_col] = df[_infection_col].fillna("None")
# =====================
# 2. Set up features and multi-class target
# =====================
# Model input columns. Order matters: it fixes the column order of the encoded
# design matrix and therefore the feature names used in the importance plot.
features = [
    # Age / sex / socioeconomic status
    'Age',
    'Sex',
    'Socioeconomic_Status',
    # Vitamin D columns
    'Vitamin_D_Level_ng/ml',
    'Vitamin_D_Status',
    'Vitamin_D_Supplemented',
    # Infection columns
    'Bacterial_Infection',
    'Viral_Infection',
    'Co_Infection',
    # IL-6 / IL-8 columns (pg/ml)
    'IL6_pg/ml',
    'IL8_pg/ml',
]
# Multi-class column the ensemble is trained to predict.
target = 'ARTI_Severity'
# =====================
# 3. Encode features and target
# =====================
# Work on a copy of the selected columns so `df` keeps its raw values.
df_encoded = df[features].copy()
# Each object-dtype column gets its own LabelEncoder; the fitted encoders are
# kept in a dict keyed by column name (and saved later in the script).
cat_cols = df_encoded.select_dtypes(include=['object']).columns
label_encoders = {col: LabelEncoder() for col in cat_cols}
for col, le in label_encoders.items():
    df_encoded[col] = le.fit_transform(df_encoded[col])
# Encode target (multi-class): classes map to 0..n_classes-1 in sorted order,
# stored both as a new column on `df` and as the Series `y`.
target_encoder = LabelEncoder()
df['ARTI_Severity_Label'] = target_encoder.fit_transform(df[target])
y = df['ARTI_Severity_Label']
# =====================
# 4. Train-test split
# =====================
# Split BEFORE scaling. Fitting the scaler on all rows would leak test-set
# means/variances into training (and into the saved scaler), making the
# held-out metrics optimistic. With the same `random_state` and `stratify`,
# the split selects the same rows as splitting the pre-scaled array did.
X_train_raw, X_test_raw, y_train, y_test = train_test_split(
    df_encoded, y, test_size=0.2, random_state=42, stratify=y
)
# =====================
# 5. Scale numerical features
# =====================
# Fit on the training split only, then apply that same transform to the test
# split. This fitted scaler is the one saved later in the script.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train_raw)
X_test = scaler.transform(X_test_raw)
# =====================
# 6. Define Models
# =====================
# Three heterogeneous base learners. VotingClassifier trains clones of these,
# so the objects bound here stay unfitted after `voting_model.fit`.
log_reg = LogisticRegression(max_iter=500)
rf = RandomForestClassifier(n_estimators=100, random_state=42)
# NOTE(review): probability=True is only needed for predict_proba / soft
# voting; with voting='hard' it adds an internal cross-validation at fit time
# without affecting predictions. Left as-is so the saved model is unchanged.
svm = SVC(probability=True)
# Majority ("hard") vote over the three predicted class labels.
base_models = [('lr', log_reg), ('rf', rf), ('svm', svm)]
voting_model = VotingClassifier(estimators=base_models, voting='hard')
# =====================
# 7. Train Model
# =====================
voting_model.fit(X_train, y_train)
# Save model and preprocessors: the fitted ensemble plus every fitted
# preprocessing object, each written to its own pickle file.
for _artifact, _artifact_path in (
    (voting_model, "voting_model_multiclass.pkl"),
    (scaler, "scaler.pkl"),
    (label_encoders, "feature_label_encoders.pkl"),
    (target_encoder, "target_label_encoder.pkl"),
):
    joblib.dump(_artifact, _artifact_path)
# =====================
# 8. Evaluation
# =====================
y_pred = voting_model.predict(X_test)
# Pin the label set to every class the target encoder knows (0..n_classes-1).
# Without it, classification_report raises when a class is missing from both
# y_test and y_pred, because `target_names` would then be longer than the
# inferred label set.
eval_labels = np.arange(len(target_encoder.classes_))
print("\n📊 Confusion Matrix:\n",
      confusion_matrix(y_test, y_pred, labels=eval_labels))
print("\n📑 Classification Report:\n",
      classification_report(y_test, y_pred, labels=eval_labels,
                            target_names=target_encoder.classes_))
print("\n✅ Accuracy Score:", accuracy_score(y_test, y_pred))
# =====================
# 9. Visualizations
# =====================
# One confusion matrix over every encoded class (0..n_classes-1), shared by
# plots 1 and 4. Pinning `labels` keeps it square and aligned with
# `target_encoder.classes_` even if a class is absent from y_test and y_pred.
viz_labels = np.arange(len(target_encoder.classes_))
cm = confusion_matrix(y_test, y_pred, labels=viz_labels)

# 1. Confusion Matrix
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=target_encoder.classes_)
disp.plot(cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
plt.savefig("confusion_matrix_multiclass.png")
plt.show()

# 2. Feature Importance (Random Forest member of the fitted ensemble)
# VotingClassifier fits clones of its estimators, so the module-level `rf` is
# never trained by `voting_model.fit`. Read importances from the forest the
# ensemble actually uses instead of training a second, separate forest.
ensemble_rf = voting_model.named_estimators_['rf']
plt.figure(figsize=(8, 5))
importances = ensemble_rf.feature_importances_
indices = np.argsort(importances)[::-1]  # most important feature first
feature_names = df_encoded.columns
# NOTE(review): seaborn >= 0.13 deprecates `palette` without `hue`; confirm
# against the installed seaborn version before changing this call.
sns.barplot(x=importances[indices], y=np.array(feature_names)[indices], palette='viridis')
plt.title("Feature Importance (Random Forest)")
plt.xlabel("Importance Score")
plt.ylabel("Features")
plt.tight_layout()
plt.savefig("feature_importance_rf.png")
plt.show()

# 3. Class Distribution (raw, un-encoded target values over the full dataset)
plt.figure(figsize=(6, 4))
sns.countplot(x=df[target], palette='pastel')
plt.title("Distribution of ARTI Severity Classes")
plt.xlabel("ARTI Severity")
plt.ylabel("Count")
plt.tight_layout()
plt.savefig("class_distribution.png")
plt.show()

# 4. Actual vs Predicted Comparison (annotated heatmap of the same matrix)
plt.figure(figsize=(8, 5))
sns.heatmap(cm, annot=True, fmt="d", cmap="YlGnBu",
            xticklabels=target_encoder.classes_, yticklabels=target_encoder.classes_)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Actual vs Predicted Heatmap")
plt.tight_layout()  # keep tick labels inside the saved image, like the other plots
plt.savefig("actual_vs_predicted_heatmap.png")
plt.show()