"""SIMCI Gradio front end: MRI-based estimate of the probability of MCI.

On first run the pretrained artefacts are downloaded; they are then loaded
once at import time and used by a Gradio Blocks UI that:

1. loads an MRI file (.nii) and shows it slice by slice,
2. runs brain extraction on it with the 3D U-Net (``utils.brain_stripping``),
3. classifies MedNet image features with an SVM, appends that label to the
   clinical scores, normalises them with a fitted scaler and feeds them to
   an MLP that outputs the probability of MCI.
"""
import logging
import os
import pickle
from urllib.request import urlretrieve

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from keras.models import load_model
from ttictoc import tic, toc

import utils

logger = logging.getLogger(__name__)

# Every volume is zero-padded along axis 0 to this many slices; the slice
# sliders therefore range over 0 .. N_SLICES - 1.
N_SLICES = 192

# --------------------------------- Model download ---------------------------------

# Local filename -> download URL of every pretrained artefact the app needs.
MODEL_URLS = {
    # 3D U-Net (brain extraction)
    "unet.h5": "https://dl.dropboxusercontent.com/s/ay5q8caqzlad7h5/unet.h5?dl=0",
    # Med3D weights used by MedNet
    "resnet_50_23dataset.pth": "https://dl.dropboxusercontent.com/s/otxsgx3e31d5h9i/resnet_50_23dataset.pth?dl=0",
    # SVM image classifier
    "svm_model.pickle": "https://dl.dropboxusercontent.com/s/n3tb3r6oyf06xfx/svm_model.pickle?dl=0",
    # Risk-level (probability) MLP
    "mlp_probabilidad.h5": "https://dl.dropboxusercontent.com/s/78fjlg374mvjygd/mlp_probabilidad.h5?dl=0",
    # Scaler for the clinical scores
    "scaler.pickle": "https://dl.dropboxusercontent.com/s/ow6pe4k45r3xkbl/scaler.pickle?dl=0",
}


def _download_if_missing(path, url):
    """Download ``url`` to ``path`` unless ``path`` already exists.

    The data is written to ``<path>.part`` and renamed into place only after
    the download completes, so an interrupted download cannot leave behind a
    truncated file that later runs would mistake for a valid model.
    """
    if os.path.exists(path):
        return
    tmp_path = path + ".part"
    try:
        urlretrieve(url, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


for _model_path, _model_url in MODEL_URLS.items():
    _download_if_missing(_model_path, _model_url)

path_3d_unet = 'unet.h5'
weight_path = 'resnet_50_23dataset.pth'
svm_path = "svm_model.pickle"
prob_model_path = "mlp_probabilidad.h5"
scaler_path = "scaler.pickle"

# ---------------------------------- Model loading ----------------------------------

# 3D U-Net
with tf.device("cpu:0"):
    model_unet = utils.import_3d_unet(path_3d_unet)

# MedNet
device_ids = [0]
mednet_model = utils.create_mednet(weight_path, device_ids)

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# These pickles come from the download URLs above -- only use trusted sources.
# SVM model
with open(svm_path, 'rb') as svm_file:
    svm_model = pickle.load(svm_file)

# Risk-level model
with tf.device("cpu:0"):
    prob_model = load_model(prob_model_path)

# Scaler
with open(scaler_path, 'rb') as scaler_file:
    scaler = pickle.load(scaler_file)

# ------------------------------------ Functions ------------------------------------


def load_img(file):
    """Load an uploaded MRI file and prepare it for display and processing.

    Returns ``(sitk_img, mri_image)``: ``sitk_img`` is the first value
    returned by ``utils.load_img`` and ``mri_image`` is a float32 array whose
    axes are reversed, zero-padded along axis 0 to ``N_SLICES`` slices and
    rotated 90 degrees in the plane of axes 1 and 2.

    Returns ``(None, None)`` when ``file`` is None. That happens when the
    Clear button resets the file input, which also fires this change handler.

    Raises:
        gr.Error: if the volume has more than ``N_SLICES`` slices.
    """
    if file is None:
        return None, None
    sitk_img, array = utils.load_img(file.name)
    # Reverse the axis order of the array returned by utils.load_img.
    mri_image = np.transpose(array)
    n_missing = N_SLICES - mri_image.shape[0]
    if n_missing < 0:
        raise gr.Error(
            f"The MRI has {mri_image.shape[0]} slices; at most {N_SLICES} are supported."
        )
    # Zero-pad only the first axis up to N_SLICES slices.
    mri_image = np.pad(mri_image, ((0, n_missing), (0, 0), (0, 0)))
    mri_image = mri_image.astype(np.float32)
    mri_image = np.rot90(mri_image, axes=(1, 2))
    return sitk_img, mri_image


def show_img(img, mri_slice, update):
    """Plot slice ``mri_slice`` (an index along axis 0) of ``img`` in grayscale.

    When ``update`` is truthy, also return two ``gr.update(visible=True)``
    values so the caller can reveal a slider and the next-step button;
    otherwise return only the figure. If ``img`` is None (nothing loaded
    yet) no figure is drawn and the other components are left unchanged.
    """
    if img is None:
        return (None, gr.update(), gr.update()) if update else None
    fig, ax = plt.subplots()
    ax.imshow(img[int(mri_slice), :, :], cmap='gray')
    # Gradio only needs the Figure object. Closing it removes it from
    # pyplot's global registry; otherwise each slider move would leak one
    # open figure on this long-running server.
    plt.close(fig)
    if update:
        return fig, gr.update(visible=True), gr.update(visible=True)
    return fig


def process_img(img, brain_slice):
    """Run brain extraction on ``img`` and display the result.

    Returns ``(brain, fig, slider_update, button_update)``:
    - ``brain`` is the extracted volume, stored in a ``gr.State``;
    - ``fig`` is its plot at ``brain_slice``;
    - the two updates reveal the brain-slice slider and the diagnosis button.
    """
    with tf.device("cpu:0"):
        brain = utils.brain_stripping(img, model_unet)
    fig, update_slider, _ = show_img(brain, brain_slice, update=True)
    return brain, fig, update_slider, gr.update(visible=True)


def get_diagnosis(brain_img, age, MMSE, GDSCALE, CDR, FAQ, NPI, sex):
    """Estimate the probability that the patient has MCI.

    MedNet features are extracted from ``brain_img`` and classified by the
    SVM. That label is appended to the clinical scores: age, MMSE, GDSCALE,
    CDR, FAQ, NPI-Q, and sex encoded as Male=1 / Female=0. The vector is
    normalised with the fitted scaler and passed to the MLP.

    Returns a ``gr.update`` carrying the diagnosis text.

    Raises:
        gr.Error: if age or sex has not been filled in.
    """
    if age is None or not sex:
        raise gr.Error("Please enter the patient's age and sex before requesting a diagnosis.")

    # Image feature extraction
    features = utils.get_features(brain_img, mednet_model)
    # Image classification. Assumes ``features`` holds a single sample, so
    # the prediction has exactly one label; use it as a scalar so the score
    # vector below is a flat numeric array.
    label_img = svm_model.predict(features)[0]

    sex_dum = 1 if sex == "Male" else 0
    scores = np.array(
        [age, MMSE, GDSCALE, CDR, FAQ, NPI, sex_dum, label_img], dtype=np.float64
    )
    logger.debug("Raw scores: %s", scores)

    # Score normalisation
    scores_norm = scaler.transform(scores.reshape(1, -1))
    logger.debug("Normalised scores: %s", scores_norm)

    with tf.device("cpu:0"):
        # Probability of having MCI
        prob = prob_model.predict(scores_norm)[0, 0]
    logger.debug("MCI probability: %s", prob)

    diagnosis = f"The patient has a probability of {(100*prob):.2f}% of having MCI"
    return gr.update(value=diagnosis)


def clear():
    """Reset the interface.

    The output order matches the ``clear_button.click`` wiring: file input,
    original plot, MRI slider, brain plot, brain slider, diagnosis text,
    process button, diagnosis button.
    """
    return (
        gr.File.update(value=None),
        gr.Plot.update(value=None),
        gr.update(visible=False),
        gr.Plot.update(value=None),
        gr.update(visible=False),
        # An empty value lets the textbox placeholder show again; using the
        # placeholder text as the value would render it as a real result.
        gr.update(value=""),
        gr.update(visible=False),
        gr.update(visible=False),
    )


# ------------------------------------ Interface ------------------------------------
# NOTE(review): the HTML markup of the header and section titles was not
# recoverable from this copy of the file; only the title texts survived.
# Restore the original markup if needed.

with gr.Blocks(theme=gr.themes.Base()) as demo:
    with gr.Row():
        gr.HTML(r"""
""")

    # Inputs
    with gr.Row():
        with gr.Column(variant="panel", scale=1):
            gr.Markdown('## Patient Information')
            with gr.Tab("Personal data"):
                input_name = gr.Textbox(placeholder='Enter the patient name', label='Patient name')
                input_age = gr.Number(label='Age')
                # Phone numbers are text: a numeric field drops leading zeros
                # and cannot hold "+" or separators.
                input_phone_num = gr.Textbox(label='Phone number')
                input_emer_name = gr.Textbox(placeholder='Enter the emergency contact name', label='Emergency contact name')
                input_emer_phone_num = gr.Textbox(label='Emergency contact phone number')
                input_sex = gr.Dropdown(["Male", "Female"], label="Sex")
            with gr.Tab("Clinical data"):
                input_MMSE = gr.Slider(minimum=0, maximum=30, value=0, step=1, label="MMSE total score")
                input_GDSCALE = gr.Slider(minimum=0, maximum=12, value=0, step=1, label="GDSCALE total score")
                input_CDR = gr.Slider(minimum=0, maximum=3, value=0, step=0.5, label="Global CDR")
                input_FAQ = gr.Slider(minimum=0, maximum=30, value=0, step=1, label="FAQ total score")
                input_NPI_Q = gr.Slider(minimum=0, maximum=30, value=0, step=1, label="NPI-Q total score")
            with gr.Tab("Vital Signs"):
                input_Diastolic_blood_pressure = gr.Number(label='Diastolic Blood Pressure(mm Hg)')
                input_Systolic_blood_pressure = gr.Number(label='Systolic Blood Pressure(mm Hg)')
                input_Body_height = gr.Number(label='Body height (cm)')
                input_Body_weight = gr.Number(label='Body weight (kg)')
                input_Heart_rate = gr.Number(label='Heart rate (bpm)')
                input_Respiratory_rate = gr.Number(label='Respiratory rate (bpm)')
                input_Body_temperature = gr.Number(label='Body temperature (°C)')
                input_Pulse_oximetry = gr.Number(label='Pulse oximetry (%)')
            with gr.Tab("Medications"):
                input_medications = gr.Textbox(label='Medications', lines=5)
                input_allergies = gr.Textbox(label='Allergies', lines=5)

            # Upload widget for the NIfTI MRI file
            input_file = gr.File(file_count="single", label="MRI Image File (.nii)")
            with gr.Row():
                # Load the image
                load_img_button = gr.Button(value="Load")
                # Clear everything
                clear_button = gr.Button(value="Clear")
            # Process the image
            process_button = gr.Button(value="Process MRI", visible=False, variant="primary")
            # Get the diagnosis
            diagnostic_button = gr.Button(value="Get diagnosis", visible=False, variant="primary")

        # Outputs
        with gr.Column(variant="panel", scale=1):
            gr.Markdown('## MRI visualization')
            with gr.Box():
                gr.Markdown('### Loaded MRI')
                # Plot of the original image
                plot_img_original = gr.Plot(show_label=False)
                # Slice selector for the original image (valid indices 0..N_SLICES-1)
                mri_slider = gr.Slider(minimum=0, maximum=N_SLICES - 1, value=100, step=1, label="MRI Slice", visible=False)
            with gr.Box():
                gr.Markdown('### Processed MRI')
                # Plot of the processed image
                plot_brain = gr.Plot(show_label=False, visible=True)
                # Slice selector for the processed image (valid indices 0..N_SLICES-1)
                brain_slider = gr.Slider(minimum=0, maximum=N_SLICES - 1, value=100, step=1, label="MRI Slice", visible=False)
            with gr.Box():
                gr.Markdown('### Diagnosis')
                # Diagnosis text
                diagnosis_text = gr.Textbox(label="Diagnosis", interactive=False, placeholder="The diagnosis will show here...")

    # Session state
    original_input_sitk = gr.State()
    original_input_img = gr.State()
    brain_img = gr.State()
    update_true = gr.State(True)
    update_false = gr.State(False)

    # Event wiring
    # Load a new image (also fires with None when Clear resets the file input)
    input_file.change(load_img, input_file, [original_input_sitk, original_input_img])
    # Show the new image
    load_img_button.click(show_img, [original_input_img, mri_slider, update_true], [plot_img_original, mri_slider, process_button])
    # Redraw the original image
    mri_slider.change(show_img, [original_input_img, mri_slider, update_false], [plot_img_original])
    # Process the image
    process_button.click(fn=process_img, inputs=[original_input_sitk, brain_slider], outputs=[brain_img, plot_brain, brain_slider, diagnostic_button])
    # Redraw the processed image
    brain_slider.change(show_img, [brain_img, brain_slider, update_false], [plot_brain])
    # Compute the diagnosis
    diagnostic_button.click(fn=get_diagnosis, inputs=[brain_img, input_age, input_MMSE, input_GDSCALE, input_CDR, input_FAQ, input_NPI_Q, input_sex], outputs=[diagnosis_text])
    # Clear all fields
    clear_button.click(fn=clear, outputs=[input_file, plot_img_original, mri_slider, plot_brain, brain_slider, diagnosis_text, process_button, diagnostic_button])

if __name__ == "__main__":
    # To serve several users concurrently, enable the request queue first,
    # e.g. demo.queue(concurrency_count=20).
    demo.launch()