File size: 2,654 Bytes
f322e76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from pathlib import Path

import gradio as gr
import joblib
import numpy as np
import pandas as pd

# Load trained components. Paths are resolved relative to this script rather
# than the current working directory, so the app finds its artifacts no
# matter which directory it is launched from.
# NOTE(review): joblib.load unpickles its input and can execute arbitrary
# code -- these files must only come from a trusted training run.
_ARTIFACT_DIR = Path(__file__).resolve().parent

model = joblib.load(_ARTIFACT_DIR / "voting_model_multiclass.pkl")
scaler = joblib.load(_ARTIFACT_DIR / "scaler.pkl")
label_encoders = joblib.load(_ARTIFACT_DIR / "feature_label_encoders.pkl")
target_encoder = joblib.load(_ARTIFACT_DIR / "target_label_encoder.pkl")

# Feature order used during training. The prediction input frame is built
# with exactly these columns, in this order, before scaling and prediction.
feature_order = [
    'Age', 'Sex', 'Socioeconomic_Status', 'Vitamin_D_Level_ng/ml',
    'Vitamin_D_Status', 'Vitamin_D_Supplemented', 'Bacterial_Infection',
    'Viral_Infection', 'Co_Infection', 'IL6_pg/ml', 'IL8_pg/ml'
]

def predict_arti_severity(
    age, sex, ses, vit_d_level, vit_d_status, vit_d_supp,
    bacterial, viral, co_infect, il6, il8
):
    """Predict the ARTI severity class for one patient from the form inputs.

    The parameters are positional and follow the same order as
    ``feature_order`` and the Gradio input components: ``age``,
    ``vit_d_level``, ``il6`` and ``il8`` come from ``gr.Number`` widgets; the
    rest are the strings chosen in ``gr.Radio`` / ``gr.Dropdown`` widgets.

    Returns:
        The string ``"Predicted ARTI Severity: <label>"``, where ``<label>``
        is the model's class decoded with ``target_encoder``.

    Raises:
        gr.Error: if any field was left empty, or a categorical value is not
            accepted by its fitted encoder. Gradio shows this message to the
            user instead of a generic failure.
    """
    values = [
        age, sex, ses, vit_d_level, vit_d_status, vit_d_supp,
        bacterial, viral, co_infect, il6, il8,
    ]

    # Gradio passes None for an unselected Radio/Dropdown or a cleared Number.
    # A None turns its column into dtype object, which would send numeric
    # fields into the label-encoding loop (KeyError) or make a categorical
    # encoder fail -- report the empty fields clearly instead.
    missing = [col for col, value in zip(feature_order, values) if value is None]
    if missing:
        raise gr.Error("Please fill in: " + ", ".join(missing))

    # Single-row frame with the column order used during training.
    input_data = pd.DataFrame([values], columns=feature_order)

    # Encode the string-valued (categorical) columns with the encoder fitted
    # for that column during training.
    for col in input_data.select_dtypes(include="object").columns:
        try:
            input_data[col] = label_encoders[col].transform(input_data[col])
        except ValueError as err:
            # e.g. a label the encoder never saw during training.
            raise gr.Error(
                f"Unrecognized value {input_data.at[0, col]!r} for {col}."
            ) from err

    # Scale, predict, and map the encoded class back to its label.
    input_scaled = scaler.transform(input_data)
    pred = model.predict(input_scaled)[0]
    pred_label = target_encoder.inverse_transform([pred])[0]

    return f"Predicted ARTI Severity: {pred_label}"

def _build_interface():
    """Return the Gradio ``Interface`` that wraps ``predict_arti_severity``.

    ``gr.Interface`` feeds component values to ``fn`` positionally, so the
    input components below are listed in the same order as that function's
    parameters.
    """
    form_inputs = [
        gr.Number(label="Age (years)", value=2),
        gr.Radio(choices=['Male', 'Female'], label="Sex"),
        gr.Radio(choices=['Low', 'Middle', 'High'], label="Socioeconomic Status"),
        gr.Number(label="Vitamin D Level (ng/ml)", value=20),
        gr.Radio(
            choices=['Deficient', 'Insufficient', 'Sufficient'],
            label="Vitamin D Status",
        ),
        gr.Radio(choices=['Yes', 'No'], label="Vitamin D Supplemented"),
        gr.Dropdown(
            choices=[
                'Streptococcus pneumoniae',
                'Klebsiella pneumoniae',
                'Staphylococcus aureus',
                'None',
            ],
            label="Bacterial Infection",
        ),
        gr.Dropdown(
            choices=[
                'RSV',
                'Influenza A',
                'Influenza B',
                'Adenovirus',
                'Rhinovirus',
                'Metapneumovirus',
                'Parainfluenza',
                'None',
            ],
            label="Viral Infection",
        ),
        gr.Radio(choices=['Yes', 'No'], label="Co-Infection"),
        gr.Number(label="IL-6 (pg/ml)", value=25),
        gr.Number(label="IL-8 (pg/ml)", value=40),
    ]
    result_box = gr.Textbox(label="Prediction Result")

    return gr.Interface(
        fn=predict_arti_severity,
        inputs=form_inputs,
        outputs=result_box,
        title="ARTI Severity Predictor",
        description="Predict ARTI severity in children based on vitamin D, infection data, and inflammation markers.",
    )


interface = _build_interface()

# Launch the web app only when this file is run as a script, so importing the
# module (e.g. from tests or another app) does not start a server.
if __name__ == "__main__":
    interface.launch()