Spaces:
Sleeping
Sleeping
import pickle | |
import gradio as gr | |
# Location of the pickled (vectorizer, classifier) pair, relative to the
# working directory the app is started from.
_MODEL_PATH = 'models/risk-model.pck'

# Load the artifacts once at import time so every request reuses them.
# NOTE: unpickling can execute arbitrary code — only deploy a model file that
# comes from a trusted source.
with open(_MODEL_PATH, 'rb') as model_file:
    dv, model = pickle.load(model_file)
def predict_single(customer, dv, model, *, threshold=0.5):
    """Score a single feature record with a fitted vectorizer/classifier pair.

    Args:
        customer: Mapping of feature name -> value for one record.
        dv: Fitted vectorizer exposing ``transform(list_of_mappings)``.
        model: Fitted classifier exposing ``predict_proba``; column 1 of its
            output is taken as the positive-class probability (scikit-learn
            orders columns by ``model.classes_``).
        threshold: Probability at or above which the record is flagged as
            at risk. Defaults to 0.5.

    Returns:
        Tuple ``(flag, probability)`` where ``flag`` is
        ``probability >= threshold``.
    """
    features = dv.transform([customer])
    # One row per record was transformed, so take row 0, positive-class column.
    probability = model.predict_proba(features)[0, 1]
    return probability >= threshold, probability
# Numeric codes for the dropdown labels shown in the UI. The patient-type codes
# follow that field's help text (1 = returned home, 2 = hospitalization); the
# sex codes are 1 = "Femenino", 2 = "Masculino".
# NOTE(review): assumes the model was fitted on these numeric codes — confirm
# against the training pipeline.
_SEX_CODES = {"Femenino": 1, "Masculino": 2}
_PATIENT_TYPE_CODES = {"Returned Home": 1, "Hospitalization": 2}


def _encode_label(value, codes, field):
    """Return the numeric code for a dropdown label; pass numeric codes through.

    Raises:
        ValueError: if ``value`` is a string that is not a key of ``codes``.
    """
    if isinstance(value, str):
        try:
            return codes[value]
        except KeyError:
            raise ValueError(
                f"Unrecognised {field} option {value!r}; "
                f"expected one of {list(codes)}"
            ) from None
    return int(value)


def predict_risk(
    sex: int,
    age: int,
    classification: int,
    patient_type: int,
    pneumonia: bool,
    pregnancy: bool,
    diabetes: bool,
    copd: bool,
    asthma: bool,
    inmsupr: bool,
    hypertension: bool,
    cardiovascular: bool,
    renal_chronic: bool,
    other_disease: bool,
    obesity: bool,
    tobacco: bool,
    usmer: int,
    medical_unit: int,
):
    """Build one feature record from the form inputs and score it.

    ``sex`` and ``patient_type`` arrive from the UI as dropdown labels
    ("Femenino"/"Masculino" and "Returned Home"/"Hospitalization"). They are
    encoded to numeric codes before scoring, because a vectorizer would
    otherwise treat the label strings as unseen categorical features. Numeric
    codes are also accepted and passed through via ``int()``. Checkbox values
    are converted to 0/1; the remaining numeric inputs are used as given.

    Returns:
        dict with ``"Risk"`` (bool decision from ``predict_single``) and
        ``"Risk Probability"`` (float rounded to 4 decimals).

    Raises:
        ValueError: if ``sex`` or ``patient_type`` is an unrecognised label.
    """
    # Key spellings (e.g. "clasiffication_final", "hipertension", "pregnant")
    # are kept verbatim; presumably they must match the feature names the
    # vectorizer was fitted on.
    customer = {
        "sex": _encode_label(sex, _SEX_CODES, "sex"),
        "age": age,
        "clasiffication_final": classification,
        "patient_type": _encode_label(
            patient_type, _PATIENT_TYPE_CODES, "patient_type"
        ),
        "pneumonia": int(pneumonia),
        "pregnant": int(pregnancy),
        "diabetes": int(diabetes),
        "copd": int(copd),
        "asthma": int(asthma),
        "inmsupr": int(inmsupr),
        "hipertension": int(hypertension),
        "cardiovascular": int(cardiovascular),
        "renal_chronic": int(renal_chronic),
        "other_disease": int(other_disease),
        "obesity": int(obesity),
        "tobacco": int(tobacco),
        "usmer": usmer,
        "medical_unit": medical_unit,
    }
    risk, prediction = predict_single(customer, dv, model)
    # Convert numpy scalars to plain Python types so the JSON output serializes.
    return {
        "Risk": bool(risk),
        "Risk Probability": round(float(prediction), 4),
    }
# Gradio UI: a title, then a row with two columns (numeric/categorical inputs
# on the left, comorbidity checkboxes on the right), followed by a button that
# sends every input to predict_risk and shows the returned dict as JSON.
# Components render in the order they are created inside each context manager,
# so statement order here determines the page layout.
with gr.Blocks() as interface:
    gr.Markdown("## COVID-19 ICU Risk Predictor")
    gr.Markdown(
        "Fill in the patient's details below. Fields have validations to ensure correct inputs."
    )
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Inputs")
            # Sends the selected label string ("Femenino" or "Masculino") to
            # the handler, not a numeric code.
            sex = gr.Dropdown(
                choices=["Femenino", "Masculino"],
                label="Sexo",
                value="Masculino",
                interactive=True,
                info="Seleccione el sexo del paciente (Femenino o Masculino).",
            )
            # NOTE(review): no default ``value`` is set, so an empty field may
            # reach the handler as None — confirm the model copes with that.
            age = gr.Number(
                label="Age",
                interactive=True,
                minimum=0,
                maximum=120,
                info="Enter the patient's age (0-120).",
            )
            # NOTE(review): the help text describes values of 4 and above, but
            # maximum=3 prevents entering them — confirm the intended range.
            classification = gr.Number(
                label="Classification",
                interactive=True,
                value=1,
                minimum=1,
                maximum=3,
                info="1-3 means the patient was diagnosed with COVID; 4+ means not diagnosed.",
            )
            # NOTE(review): this sends the label string, not the 1/2 code
            # described in ``info`` — the handler must map it to the code.
            patient_type = gr.Dropdown(
                choices=["Returned Home", "Hospitalization"],
                label="Patient Type",
                value="Returned Home",
                interactive=True,
                info="1 for Returned Home, 2 for Hospitalization",
            )
            usmer = gr.Number(
                label="USMER",
                interactive=True,
                value=1,
                minimum=1,
                maximum=3,
                info="Medical units: 1 for First Level, 2 for Second Level, 3 for Third Level",
            )
            # Only a lower bound is enforced here; no maximum is set.
            medical_unit = gr.Number(
                label="Medical Unit",
                interactive=True,
                value=1,
                minimum=1,
                info="Type of institution of the National Health System.",
            )
        with gr.Column():
            gr.Markdown("### Binary Inputs")
            # Two sub-columns of checkboxes; each value is converted to 0/1
            # by the handler.
            with gr.Row():
                with gr.Column():
                    pneumonia = gr.Checkbox(label="Pneumonia")
                    pregnancy = gr.Checkbox(label="Pregnancy")
                    diabetes = gr.Checkbox(label="Diabetes")
                    copd = gr.Checkbox(label="COPD")
                    asthma = gr.Checkbox(label="Asthma")
                    inmsupr = gr.Checkbox(label="Immunosuppression")
                with gr.Column():
                    hypertension = gr.Checkbox(label="Hypertension")
                    cardiovascular = gr.Checkbox(label="Cardiovascular Disease")
                    renal_chronic = gr.Checkbox(label="Chronic Renal Disease")
                    other_disease = gr.Checkbox(label="Other Disease")
                    obesity = gr.Checkbox(label="Obesity")
                    tobacco = gr.Checkbox(label="Tobacco Use")
    predict_btn = gr.Button("Predict Risk")
    output = gr.JSON(label="Prediction Result")
    # The order of ``inputs`` must match predict_risk's positional parameters.
    predict_btn.click(
        predict_risk,
        inputs=[
            sex, age, classification, patient_type, pneumonia, pregnancy,
            diabetes, copd, asthma, inmsupr, hypertension, cardiovascular,
            renal_chronic, other_disease, obesity, tobacco, usmer, medical_unit,
        ],
        outputs=output,
    )
def main() -> None:
    """Launch the Gradio app defined by ``interface``."""
    interface.launch()


# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    main()