import os
import utils
import pickle
import numpy as np
import gradio as gr
import tensorflow as tf
import matplotlib.pyplot as plt
from ttictoc import tic,toc
from keras.models import load_model
from urllib.request import urlretrieve
# ------------------------------ Model download ------------------------------
# Local file names of the pretrained artifacts used by the app.
path_3d_unet = 'unet.h5'                    # 3D U-Net
weight_path = 'resnet_50_23dataset.pth'     # Med3D weights
svm_path = "svm_model.pickle"               # SVM image classifier
prob_model_path = "mlp_probabilidad.h5"     # risk-level (probability) model
scaler_path = "scaler.pickle"               # scaler for the clinical scores

# Where each artifact is fetched from when it is not present locally.
# Insertion order is the download order.
_MODEL_URLS = {
    path_3d_unet: "https://dl.dropboxusercontent.com/s/ay5q8caqzlad7h5/unet.h5?dl=0",
    weight_path: "https://dl.dropboxusercontent.com/s/otxsgx3e31d5h9i/resnet_50_23dataset.pth?dl=0",
    svm_path: "https://dl.dropboxusercontent.com/s/n3tb3r6oyf06xfx/svm_model.pickle?dl=0",
    prob_model_path: "https://dl.dropboxusercontent.com/s/78fjlg374mvjygd/mlp_probabilidad.h5?dl=0",
    scaler_path: "https://dl.dropboxusercontent.com/s/ow6pe4k45r3xkbl/scaler.pickle?dl=0",
}


def _download_if_missing(url, path):
    """Download *url* to *path* unless *path* already exists.

    The data is written to ``path + ".part"`` first and only renamed to *path*
    once ``urlretrieve`` returns. An interrupted or short download therefore
    never leaves a truncated file at *path*; such a file would otherwise pass
    the existence check on every later start and never be re-fetched.
    """
    if os.path.exists(path):
        return
    tmp_path = path + ".part"
    try:
        urlretrieve(url, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Only present if the download or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


for _path, _url in _MODEL_URLS.items():
    _download_if_missing(_url, _path)
# ------------------------------- Model loading -------------------------------
def _load_pickle(path):
    """Unpickle and return the object stored in the file at *path*.

    The file is opened in a ``with`` block so its handle is closed immediately
    instead of being left to the garbage collector.

    SECURITY: unpickling can execute arbitrary code. The pickles loaded here
    are downloaded at start-up, so only point their URLs at trusted sources.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


# 3D U-Net, loaded under a CPU device scope.
with tf.device("cpu:0"):
    model_unet = utils.import_3d_unet(path_3d_unet)
# MedNet feature extractor.
# NOTE(review): device_ids presumably selects the device(s) used by
# utils.create_mednet — confirm in utils.
device_ids = [0]
mednet_model = utils.create_mednet(weight_path, device_ids)
# SVM image classifier.
svm_model = _load_pickle(svm_path)
# Risk-level model, loaded under a CPU device scope.
with tf.device("cpu:0"):
    prob_model = load_model(prob_model_path)
# Scaler applied to the clinical scores before the risk-level model.
scaler = _load_pickle(scaler_path)
# --------------------------------- Functions ---------------------------------
def _prepare_volume(array, depth=192):
    """Transpose, zero-pad and rotate a raw voxel array for display/processing.

    Args:
        array: Voxel array as returned by ``utils.load_img``. After
            transposition, axis 0 is treated as the slice axis.
        depth: Number of slices the volume is zero-padded to along axis 0.
            The default, 192, is the value this app has always used.

    Returns:
        A float32 array with ``depth`` slices on axis 0, rotated 90 degrees
        in the plane of axes 1 and 2 (``np.rot90(..., axes=(1, 2))``).

    Raises:
        ValueError: if the volume already has more than ``depth`` slices.
    """
    volume = np.transpose(array)
    missing = depth - volume.shape[0]
    if missing < 0:
        raise ValueError(
            f"MRI volume has {volume.shape[0]} slices; at most {depth} are supported"
        )
    # Append zero slices at the end of axis 0 only; other axes are untouched.
    pad_width = [(0, missing)] + [(0, 0)] * (volume.ndim - 1)
    volume = np.pad(volume, pad_width).astype(np.float32)
    return np.rot90(volume, axes=(1, 2))


def load_img(file):
    """Load an uploaded MRI file and prepare its volume.

    Args:
        file: The uploaded-file object from ``gr.File`` (its ``.name`` is the
            path passed to ``utils.load_img``), or ``None`` when the file input
            has been emptied.

    Returns:
        ``(sitk, mri_image)``: the first value returned by ``utils.load_img``
        and the prepared float32 volume (see ``_prepare_volume``), or
        ``(None, None)`` when no file is given.
    """
    # The "Clear" button resets the file input to None, which also triggers
    # its change handler; without this guard, ``file.name`` would raise.
    if file is None:
        return None, None
    sitk, array = utils.load_img(file.name)
    return sitk, _prepare_volume(array)
def show_img(img, mri_slice, update):
    """Render one slice of a volume as a grayscale matplotlib figure.

    Args:
        img: Volume indexed as ``img[slice, :, :]``, or ``None`` when nothing
            has been loaded yet.
        mri_slice: Slice index (from a slider). Converted to ``int`` and
            clamped to ``[0, img.shape[0] - 1]``.
        update: If truthy, also return two updates for the components the
            caller wires as the second and third outputs.

    Returns:
        The figure, or ``(figure, update_a, update_b)`` when ``update`` is
        truthy. The updates make both components visible when an image was
        drawn, and leave them unchanged when ``img`` is ``None``.
    """
    # Build the figure without pyplot. pyplot keeps every figure alive in a
    # global registry (these were never closed), and its implicit "current
    # figure" state is shared if Gradio runs handlers concurrently.
    from matplotlib.figure import Figure

    fig = Figure()
    if img is None:
        # Nothing to draw (e.g. "Load" pressed before a file was uploaded):
        # return an empty figure and do not reveal any further controls.
        return (fig, gr.update(), gr.update()) if update else fig

    # Clamp so an out-of-range slider value (e.g. a maximum equal to the
    # depth) shows the nearest valid slice instead of raising IndexError.
    index = min(max(int(mri_slice), 0), img.shape[0] - 1)
    ax = fig.subplots()
    ax.imshow(img[index, :, :], cmap='gray')
    if update:
        return fig, gr.update(visible=True), gr.update(visible=True)
    return fig
def process_img(img, brain_slice):
    """Skull-strip the loaded image and preview one slice of the result.

    Args:
        img: Image handed in by the UI, forwarded unchanged to
            ``utils.brain_stripping`` together with the 3D U-Net.
        brain_slice: Slice index used for the initial preview.

    Returns:
        ``(brain, figure, slider_update, button_update)``:
        - the stripped volume;
        - its preview figure;
        - the visibility update produced by ``show_img`` (used for the slice slider);
        - an update that makes the diagnosis button visible.
    """
    # Keep U-Net inference on the CPU.
    with tf.device("cpu:0"):
        stripped = utils.brain_stripping(img, model_unet)
    figure, reveal_slider, _unused = show_img(stripped, brain_slice, update=True)
    reveal_button = gr.update(visible=True)
    return stripped, figure, reveal_slider, reveal_button
def _build_score_vector(age, MMSE, GDSCALE, CDR, FAQ, NPI, sex, label):
    """Assemble the single (1, 8) float64 row fed to the scaler.

    Column order: age, MMSE, GDSCALE, CDR, FAQ, NPI, sex (1 for "Male",
    0 otherwise), image label. ``label`` may be a scalar or the raw output of
    the classifier's ``predict``; its first element is used. The original code
    nested that array inside the row, which builds a ragged array that NumPy
    >= 1.24 rejects.
    """
    sex_dum = 1 if sex == "Male" else 0
    label_value = np.ravel(label)[0]
    return np.array(
        [[age, MMSE, GDSCALE, CDR, FAQ, NPI, sex_dum, label_value]],
        dtype=np.float64,
    )


def get_diagnosis(brain_img, age, MMSE, GDSCALE, CDR, FAQ, NPI, sex):
    """Estimate the probability of MCI from the processed MRI and clinical data.

    Args:
        brain_img: Skull-stripped volume produced by ``process_img``.
        age, MMSE, GDSCALE, CDR, FAQ, NPI: Clinical scores from the form.
        sex: "Male" or "Female" from the dropdown (``None`` if not chosen).

    Returns:
        A ``gr.update`` whose value is the diagnosis sentence, or a prompt to
        fill in age and sex when either is missing.
    """
    import logging

    logger = logging.getLogger(__name__)

    # Validate before the expensive feature extraction. Previously a missing
    # sex was silently encoded as "Female" and a missing age crashed inside
    # the scaler.
    if age is None or sex not in ("Male", "Female"):
        return gr.update(value="Please enter the patient's age and sex to get a diagnosis.")

    # Image features and image-level classification.
    features = utils.get_features(brain_img, mednet_model)
    # NOTE(review): presumably one label per sample, a single sample here.
    label = svm_model.predict(features)

    scores = _build_score_vector(age, MMSE, GDSCALE, CDR, FAQ, NPI, sex, label)
    # Patient data goes to debug logging rather than stdout.
    logger.debug("Raw scores: %s", scores)
    scores_norm = scaler.transform(scores)
    logger.debug("Normalized scores: %s", scores_norm)

    # Probability of MCI from the risk-level model.
    with tf.device("cpu:0"):
        prob = prob_model.predict(scores_norm)[0, 0]
    logger.debug("MCI probability: %s", prob)
    diagnosis = f"The patient has a probability of {(100*prob):.2f}% of having MCI"
    return gr.update(value=diagnosis)
def clear():
    """Reset the MRI upload, plots, sliders, buttons and diagnosis text.

    Returns one update per output component, in the order wired to the
    "Clear" button: input_file, plot_img_original, mri_slider, plot_brain,
    brain_slider, diagnosis_text, process_button, diagnostic_button.
    """
    return (
        gr.update(value=None),     # input_file
        gr.update(value=None),     # plot_img_original
        gr.update(visible=False),  # mri_slider
        gr.update(value=None),     # plot_brain
        gr.update(visible=False),  # brain_slider
        # Empty value so the textbox's placeholder shows again; the placeholder
        # text used to be written in as the actual value.
        gr.update(value=""),       # diagnosis_text
        gr.update(visible=False),  # process_button
        gr.update(visible=False),  # diagnostic_button
    )
# ------------------------------ User interface ------------------------------
with gr.Blocks(theme=gr.themes.Base()) as demo:
    # Header
    with gr.Row():
        # NOTE(review): the header's HTML markup appears to have been stripped
        # from this copy of the file (only empty raw strings were left behind);
        # restore the original banner content here.
        gr.HTML("")
    # Inputs
    with gr.Row():
        with gr.Column(variant="panel", scale=1):
            gr.Markdown("Patient Information")
            with gr.Tab("Personal data"):
                input_name = gr.Textbox(placeholder='Enter the patient name', label='Patient name')
                input_age = gr.Number(label='Age')
                input_phone_num = gr.Number(label='Phone number')
                input_emer_name = gr.Textbox(placeholder='Enter the emergency contact name',
                                             label='Emergency contact name')
                input_emer_phone_num = gr.Number(label='Emergency contact phone number')
                input_sex = gr.Dropdown(["Male", "Female"], label="Sex")
            with gr.Tab("Clinical data"):
                input_MMSE = gr.Slider(minimum=0,
                                       maximum=30,
                                       value=0,
                                       step=1,
                                       label="MMSE total score")
                input_GDSCALE = gr.Slider(minimum=0,
                                          maximum=12,
                                          value=0,
                                          step=1,
                                          label="GDSCALE total score")
                input_CDR = gr.Slider(minimum=0,
                                      maximum=3,
                                      value=0,
                                      step=0.5,
                                      label="Global CDR")
                input_FAQ = gr.Slider(minimum=0,
                                      maximum=30,
                                      value=0,
                                      step=1,
                                      label="FAQ total score")
                input_NPI_Q = gr.Slider(minimum=0,
                                        maximum=30,
                                        value=0,
                                        step=1,
                                        label="NPI-Q total score")
            with gr.Tab("Vital Signs"):
                input_Diastolic_blood_pressure = gr.Number(label='Diastolic Blood Pressure (mm Hg)')
                input_Systolic_blood_pressure = gr.Number(label='Systolic Blood Pressure (mm Hg)')
                input_Body_heigth = gr.Number(label='Body height (cm)')
                input_Body_weight = gr.Number(label='Body weight (kg)')
                input_Heart_rate = gr.Number(label='Heart rate (bpm)')
                input_Respiratory_rate = gr.Number(label='Respiratory rate (bpm)')
                input_Body_temperature = gr.Number(label='Body temperature (°C)')
                input_Pluse_oximetry = gr.Number(label='Pulse oximetry (%)')
            with gr.Tab("Medications"):
                input_medications = gr.Textbox(label='Medications', lines=5)
                input_allergies = gr.Textbox(label='Allergies', lines=5)
            # Upload widget for the MRI file
            input_file = gr.File(file_count="single", label="MRI Image File (.nii)")
            with gr.Row():
                # Button that displays the uploaded image
                load_img_button = gr.Button(value="Load")
                # Button that resets the MRI-related components
                clear_button = gr.Button(value="Clear")
            # Button that skull-strips the image (revealed after "Load")
            process_button = gr.Button(value="Process MRI", visible=False, variant="primary")
            # Button that computes the diagnosis (revealed after processing)
            diagnostic_button = gr.Button(value="Get diagnosis", visible=False, variant="primary")
        # Outputs
        with gr.Column(variant="panel", scale=1):
            gr.Markdown("MRI visualization")
            with gr.Box():
                gr.Markdown("Loaded MRI")
                # Plot of the original image
                plot_img_original = gr.Plot(show_label=False)
                # Slice selector for the original image. load_img pads every
                # volume to 192 slices, so the last valid index is 191.
                mri_slider = gr.Slider(minimum=0,
                                       maximum=191,
                                       value=100,
                                       step=1,
                                       label="MRI Slice",
                                       visible=False)
            with gr.Box():
                gr.Markdown("Processed MRI")
                # Plot of the processed (skull-stripped) image
                plot_brain = gr.Plot(show_label=False, visible=True)
                # Slice selector for the processed image.
                # NOTE(review): assumes the skull-stripped volume keeps the
                # 192-slice depth — confirm against utils.brain_stripping.
                brain_slider = gr.Slider(minimum=0,
                                         maximum=191,
                                         value=100,
                                         step=1,
                                         label="MRI Slice",
                                         visible=False)
            with gr.Box():
                gr.Markdown("Diagnosis")
                # Diagnosis text
                diagnosis_text = gr.Textbox(label="Diagnosis", interactive=False,
                                            placeholder="The diagnosis will show here...")
    # Session state shared between the event handlers
    original_input_sitk = gr.State()
    original_input_img = gr.State()
    brain_img = gr.State()
    update_true = gr.State(True)
    update_false = gr.State(False)
    # Event wiring
    # Load a newly uploaded file into the session state
    input_file.change(load_img,
                      input_file,
                      [original_input_sitk, original_input_img])
    # Show the loaded image and reveal its slider and the "Process MRI" button
    load_img_button.click(show_img,
                          [original_input_img, mri_slider, update_true],
                          [plot_img_original, mri_slider, process_button])
    # Redraw the original image when its slider moves
    mri_slider.change(show_img,
                      [original_input_img, mri_slider, update_false],
                      [plot_img_original])
    # Process the image
    process_button.click(fn=process_img,
                         inputs=[original_input_sitk, brain_slider],
                         outputs=[brain_img, plot_brain, brain_slider, diagnostic_button])
    # Redraw the processed image when its slider moves
    brain_slider.change(show_img,
                        [brain_img, brain_slider, update_false],
                        [plot_brain])
    # Compute the diagnosis
    diagnostic_button.click(fn=get_diagnosis,
                            inputs=[brain_img, input_age, input_MMSE, input_GDSCALE, input_CDR,
                                    input_FAQ, input_NPI_Q, input_sex],
                            outputs=[diagnosis_text])
    # Clear the MRI-related fields
    clear_button.click(fn=clear,
                       outputs=[input_file, plot_img_original, mri_slider, plot_brain,
                                brain_slider, diagnosis_text, process_button, diagnostic_button])
def main():
    """Start the Gradio server for the SIMCI interface."""
    # Request queuing is disabled; to enable it, call
    # demo.queue(concurrency_count=20) before launching.
    demo.launch()


if __name__ == "__main__":
    main()