# Spaces: Running
import os

import gradio as gr
import joblib
import lightgbm as lgb
import pandas as pd
import xgboost as xgb

from obesity_rp import config as cfg
# Module-level caches filled on demand by predict_obesity_risk, so each
# model's artifacts are read from disk at most once per process.

# Model name -> estimator object returned by load_model_artifacts.
loaded_models: dict = {}
# Model name -> feature columns saved alongside that model.
loaded_model_columns_map: dict = {}
# Label encoder used to decode predictions; None until the first successful load.
label_encoder = None
def load_model_artifacts(model_name):
    """
    Load the trained model, its feature columns, and the label encoder.

    Args:
        model_name: Identifier used to build the artifact file names inside
            ``cfg.MODEL_DIR``.

    Returns:
        A ``(model, feature_columns, label_encoder)`` tuple, each element
        being whatever joblib deserializes from the corresponding file.

    Raises:
        FileNotFoundError: If any of the three artifact files does not exist.
            The message lists the paths that are missing.
    """
    model_file = os.path.join(cfg.MODEL_DIR, f"obesity_{model_name}_model.joblib")
    columns_file = os.path.join(cfg.MODEL_DIR, f"{model_name}_model_columns.joblib")
    encoder_file = os.path.join(cfg.MODEL_DIR, "label_encoder.joblib")

    # Collect every missing path up front so the caller sees all of them at once.
    missing_files = [
        path
        for path in (model_file, columns_file, encoder_file)
        if not os.path.exists(path)
    ]
    if missing_files:
        raise FileNotFoundError(
            f"Model artifacts for '{model_name}' not found. Please ensure all required files exist."
            f" Missing: {', '.join(missing_files)}"
        )

    # NOTE(security): joblib.load unpickles arbitrary objects; only point
    # cfg.MODEL_DIR at artifacts produced by a trusted training run.
    loaded_model = joblib.load(model_file)
    loaded_model_columns = joblib.load(columns_file)
    le = joblib.load(encoder_file)

    print(
        f"{model_name} Model, feature columns, and label encoder loaded for prediction."
    )
    return loaded_model, loaded_model_columns, le
def predict_obesity_risk(
    model_choice,
    Gender,
    Age,
    Height,
    Weight,
    family_history_with_overweight,
    FAVC,
    FCVC,
    NCP,
    CAEC,
    SMOKE,
    CH2O,
    SCC,
    FAF,
    TUE,
    CALC,
    MTRANS,
):
    """
    Predict the obesity risk category for one set of inputs.

    Args:
        model_choice: Name of the model to use; its artifacts are loaded via
            ``load_model_artifacts`` on first use and cached afterwards.
        Age, Height, Weight, FCVC, NCP, CH2O, FAF, TUE: Numeric features,
            written into the column of the same name when the model has one.
        Gender, family_history_with_overweight, FAVC, CAEC, SMOKE, SCC, CALC,
            MTRANS: Categorical features. Each sets the one-hot column named
            ``"<feature>_<value>"`` to 1 when the model has such a column;
            otherwise all of that feature's columns stay at 0.

    Returns:
        A multi-line string with the predicted label followed by the
        probability of every class, or a string starting with ``"Error:"``
        when an input is missing or the model artifacts cannot be found.
    """
    global label_encoder

    numeric_inputs = {
        "Age": Age,
        "Height": Height,
        "Weight": Weight,
        "FCVC": FCVC,
        "NCP": NCP,
        "CH2O": CH2O,
        "FAF": FAF,
        "TUE": TUE,
    }
    categorical_inputs = {
        "Gender": Gender,
        "family_history_with_overweight": family_history_with_overweight,
        "FAVC": FAVC,
        "CAEC": CAEC,
        "SMOKE": SMOKE,
        "SCC": SCC,
        "CALC": CALC,
        "MTRANS": MTRANS,
    }

    # An unselected widget arrives as None. Without this guard a None category
    # matches no one-hot column and is silently scored as the all-zero case.
    missing_inputs = [
        name
        for name, value in {**numeric_inputs, **categorical_inputs}.items()
        if value is None
    ]
    if missing_inputs:
        return (
            f"Error: Missing value for input(s): {', '.join(missing_inputs)}."
            " Please fill in every field."
        )

    if model_choice in loaded_models:
        model = loaded_models[model_choice]
        columns = loaded_model_columns_map[model_choice]
        le = label_encoder
    else:
        try:
            model, columns, le = load_model_artifacts(model_choice)
        except FileNotFoundError as e:
            return f"Error: {e}. Model '{model_choice}' not found. Please train the model first."
        loaded_models[model_choice] = model
        loaded_model_columns_map[model_choice] = columns
        if label_encoder is None:
            label_encoder = le

    # Build the single feature row as floats from the start. Creating an
    # integer-zero frame and then assigning fractional values (e.g. Height)
    # relies on silent dtype upcasting, which newer pandas deprecates.
    feature_row = dict.fromkeys(columns, 0.0)
    for name, value in numeric_inputs.items():
        if name in feature_row:
            feature_row[name] = float(value)
    for prefix, value in categorical_inputs.items():
        dummy_column = f"{prefix}_{value}"
        if dummy_column in feature_row:
            feature_row[dummy_column] = 1.0

    # Passing `columns` keeps the feature order the model was saved with.
    input_df = pd.DataFrame([feature_row], columns=list(columns))

    prediction_proba = model.predict_proba(input_df)[0]
    prediction_encoded = model.predict(input_df)[0]
    prediction_label = le.inverse_transform([prediction_encoded])[0]

    report_lines = [
        f"Using {model_choice} Model:",
        f"Prediction: {prediction_label}",
        "",
        "--- Prediction Probabilities ---",
    ]
    # Probabilities are paired with le.classes_ by position.
    report_lines.extend(
        f"{class_name}: {prob * 100:.2f}%"
        for class_name, prob in zip(le.classes_, prediction_proba)
    )
    return "\n".join(report_lines) + "\n"
def launch_gradio_app(share=False):
    """
    Build the Gradio interface for obesity risk prediction and launch it.

    Args:
        share: Passed through unchanged to ``Interface.launch``.
    """
    print("\n--- Starting Gradio App ---")

    def yes_no_radio(label):
        # Two-option radio used by every yes/no question.
        return gr.Radio(choices=["yes", "no"], label=label)

    def unit_step_slider(low, high, label):
        # Slider that moves in whole-number steps between low and high.
        return gr.Slider(minimum=low, maximum=high, step=1, label=label)

    def frequency_dropdown(label):
        # Dropdown with the four frequency levels shared by CAEC and CALC.
        return gr.Dropdown(
            choices=["no", "Sometimes", "Frequently", "Always"],
            label=label,
        )

    # The list order is the positional argument order of predict_obesity_risk.
    input_components = [
        gr.Dropdown(
            choices=cfg.MODEL_CHOICES, label="Select Model", value=cfg.RANDOM_FOREST
        ),
        gr.Dropdown(choices=["Female", "Male"], label="Gender"),
        unit_step_slider(1, 100, "Age"),
        gr.Slider(minimum=1.0, maximum=2.2, step=0.01, label="Height (m)"),
        gr.Slider(minimum=30.0, maximum=200.0, step=0.1, label="Weight (kg)"),
        yes_no_radio("Family History with Overweight"),
        yes_no_radio("Frequent consumption of high caloric food (FAVC)"),
        unit_step_slider(1, 3, "Frequency of consumption of vegetables (FCVC)"),
        unit_step_slider(1, 4, "Number of main meals (NCP)"),
        frequency_dropdown("Consumption of food between meals (CAEC)"),
        yes_no_radio("SMOKE"),
        unit_step_slider(1, 3, "Consumption of water daily (CH2O)"),
        yes_no_radio("Calories consumption monitoring (SCC)"),
        unit_step_slider(0, 3, "Physical activity frequency (FAF)"),
        unit_step_slider(0, 2, "Time using technology devices (TUE)"),
        frequency_dropdown("Consumption of alcohol (CALC)"),
        gr.Dropdown(
            choices=[
                "Automobile",
                "Motorbike",
                "Bike",
                "Public_Transportation",
                "Walking",
            ],
            label="Transportation used (MTRANS)",
        ),
    ]
    result_box = gr.Textbox(label="Obesity Risk Prediction Result", lines=10)

    # One example row per model, in the same field order as input_components.
    example_rows = [
        [
            cfg.RANDOM_FOREST, "Male", 25, 1.8, 85, "yes", "yes", 2, 3,
            "Sometimes", "no", 2, "no", 1, 1, "Frequently",
            "Public_Transportation",
        ],
        [
            cfg.LIGHTGBM, "Female", 30, 1.65, 70, "yes", "yes", 3, 3,
            "Frequently", "no", 3, "yes", 2, 0, "Sometimes",
            "Automobile",
        ],
        [
            cfg.XGBOOST, "Female", 21, 1.52, 56, "yes", "no", 3, 3,
            "Sometimes", "yes", 3, "yes", 3, 0, "Sometimes",
            "Public_Transportation",
        ],
    ]

    app = gr.Interface(
        fn=predict_obesity_risk,
        inputs=input_components,
        outputs=result_box,
        title="Obesity Risk Prediction (Multi-Model)",
        description="Select a machine learning model and enter patient details to predict the obesity risk category.",
        examples=example_rows,
    )
    app.launch(share=share)
    print("--- Gradio App Launched ---")
# Script entry point: start the app only when this file is executed directly.
if __name__ == "__main__":
    launch_gradio_app(share=False)