Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
import joblib | |
# Fitted artifacts produced during training, loaded once when this module runs
# and shared by every prediction call. Paths are resolved against the current
# working directory. joblib.load unpickles its input, so these files must come
# from a trusted source.
model, scaler, label_encoders, target_encoder = (
    joblib.load(artifact_path)
    for artifact_path in (
        "voting_model_multiclass.pkl",  # classifier: used via .predict
        "scaler.pkl",                   # feature scaler: used via .transform
        "feature_label_encoders.pkl",   # mapping column name -> fitted encoder
        "target_label_encoder.pkl",     # decodes predicted class via .inverse_transform
    )
)

# Column order used during training; the prediction row must be built with
# exactly these columns in exactly this order.
feature_order = [
    'Age',
    'Sex',
    'Socioeconomic_Status',
    'Vitamin_D_Level_ng/ml',
    'Vitamin_D_Status',
    'Vitamin_D_Supplemented',
    'Bacterial_Infection',
    'Viral_Infection',
    'Co_Infection',
    'IL6_pg/ml',
    'IL8_pg/ml',
]
def predict_arti_severity(
    age, sex, ses, vit_d_level, vit_d_status, vit_d_supp,
    bacterial, viral, co_infect, il6, il8
):
    """Predict ARTI severity for a single patient from the form inputs.

    The eleven arguments map positionally onto ``feature_order``. String
    (categorical) fields are encoded with the per-column encoder stored in
    ``label_encoders``. The row is then scaled with ``scaler``, classified
    with ``model``, and the predicted class is decoded with ``target_encoder``.

    Returns:
        The string ``"Predicted ARTI Severity: <label>"``.

    Raises:
        gr.Error: if any field is empty (``None``), or if a categorical value
            is rejected by its encoder (e.g. a label not seen during training).
            Gradio displays this message to the user instead of a generic
            error.
    """
    values = [
        age, sex, ses, vit_d_level, vit_d_status, vit_d_supp,
        bacterial, viral, co_infect, il6, il8,
    ]

    # Unselected Radio/Dropdown widgets and cleared Number boxes arrive as
    # None. Left unchecked, a None numeric field turns its column into object
    # dtype and fails with a KeyError on label_encoders, and a None categorical
    # field fails inside the encoder.
    missing = [name for name, value in zip(feature_order, values) if value is None]
    if missing:
        raise gr.Error("Please provide a value for: " + ", ".join(missing))

    # Single-row DataFrame whose columns match the training layout.
    input_data = pd.DataFrame([values], columns=feature_order)

    # Categorical inputs are the string-valued (object dtype) columns; encode
    # each with the encoder fitted for that column.
    for col in input_data.select_dtypes(include='object').columns:
        le = label_encoders[col]
        raw_value = input_data.at[0, col]
        try:
            input_data[col] = le.transform(input_data[col])
        except ValueError as err:
            # Typically an unseen label; report what the encoder accepts
            # when it exposes that information.
            known = getattr(le, "classes_", None)
            hint = "" if known is None else (
                " Expected one of: " + ", ".join(map(str, known)) + "."
            )
            raise gr.Error(f"Unrecognised value {raw_value!r} for {col}.{hint}") from err

    # Scale with the fitted scaler, predict, and decode back to the label.
    input_scaled = scaler.transform(input_data)
    pred = model.predict(input_scaled)[0]
    pred_label = target_encoder.inverse_transform([pred])[0]
    return f"Predicted ARTI Severity: {pred_label}"
# Form widgets in the same order as predict_arti_severity's parameters;
# Gradio passes the widget values to the function positionally.
input_components = [
    gr.Number(value=2, label="Age (years)"),
    gr.Radio(label="Sex", choices=['Male', 'Female']),
    gr.Radio(label="Socioeconomic Status", choices=['Low', 'Middle', 'High']),
    gr.Number(value=20, label="Vitamin D Level (ng/ml)"),
    gr.Radio(label="Vitamin D Status", choices=['Deficient', 'Insufficient', 'Sufficient']),
    gr.Radio(label="Vitamin D Supplemented", choices=['Yes', 'No']),
    gr.Dropdown(
        label="Bacterial Infection",
        choices=['Streptococcus pneumoniae', 'Klebsiella pneumoniae', 'Staphylococcus aureus', 'None'],
    ),
    gr.Dropdown(
        label="Viral Infection",
        choices=['RSV', 'Influenza A', 'Influenza B', 'Adenovirus', 'Rhinovirus', 'Metapneumovirus', 'Parainfluenza', 'None'],
    ),
    gr.Radio(label="Co-Infection", choices=['Yes', 'No']),
    gr.Number(value=25, label="IL-6 (pg/ml)"),
    gr.Number(value=40, label="IL-8 (pg/ml)"),
]

# Web front end: the form feeds predict_arti_severity, and its returned
# string is shown in a single text box.
interface = gr.Interface(
    fn=predict_arti_severity,
    inputs=input_components,
    outputs=gr.Textbox(label="Prediction Result"),
    title="ARTI Severity Predictor",
    description="Predict ARTI severity in children based on vitamin D, infection data, and inflammation markers.",
)

# Start the Gradio server; this runs unconditionally whenever the module executes.
interface.launch()