# Spaces:
# Paused
# Paused
import os | |
import utils | |
import pickle | |
import numpy as np | |
import gradio as gr | |
import tensorflow as tf | |
import matplotlib.pyplot as plt | |
from ttictoc import tic,toc | |
from keras.models import load_model | |
from urllib.request import urlretrieve | |
# --------------------------- Model download ---------------------------------
# Local file names of the pretrained artifacts (each is fetched below if absent).
path_3d_unet = 'unet.h5'                   # 3D U-Net
weight_path = 'resnet_50_23dataset.pth'    # Med3D weights
svm_path = "svm_model.pickle"              # SVM image classifier
prob_model_path = "mlp_probabilidad.h5"    # risk-level model
scaler_path = "scaler.pickle"              # scaler for the clinical scores

# Download source for each local file, in download order.
_MODEL_URLS = {
    path_3d_unet: "https://dl.dropboxusercontent.com/s/ay5q8caqzlad7h5/unet.h5?dl=0",
    weight_path: "https://dl.dropboxusercontent.com/s/otxsgx3e31d5h9i/resnet_50_23dataset.pth?dl=0",
    svm_path: "https://dl.dropboxusercontent.com/s/n3tb3r6oyf06xfx/svm_model.pickle?dl=0",
    prob_model_path: "https://dl.dropboxusercontent.com/s/78fjlg374mvjygd/mlp_probabilidad.h5?dl=0",
    scaler_path: "https://dl.dropboxusercontent.com/s/ow6pe4k45r3xkbl/scaler.pickle?dl=0",
}


def _download_if_missing(url, path):
    """Download ``url`` to ``path`` unless ``path`` already exists.

    The payload is written to a temporary sibling file and only moved into
    place once the download has finished. An interrupted download therefore
    never leaves a truncated file behind that the existence check would
    accept on the next start-up.
    """
    if os.path.exists(path):
        return
    tmp_path = path + ".part"
    try:
        urlretrieve(url, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Only present if the download or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


for _path, _url in _MODEL_URLS.items():
    _download_if_missing(_url, _path)
# ---------------------------- Model loading ---------------------------------
# 3D U-Net, built under a CPU device scope
with tf.device("cpu:0"):
    model_unet = utils.import_3d_unet(path_3d_unet)
# MedNet feature extractor (Med3D weights).
# NOTE(review): device_ids is presumably the list of GPU ids used by
# utils.create_mednet — confirm.
device_ids = [0]
mednet_model = utils.create_mednet(weight_path, device_ids)
# SVM image classifier.
# NOTE: pickle.load can execute arbitrary code embedded in the file; only load
# artifacts from sources you trust (here: the hard-coded download URLs).
with open(svm_path, 'rb') as _fh:
    svm_model = pickle.load(_fh)
# Risk-level model, loaded under a CPU device scope
with tf.device("cpu:0"):
    prob_model = load_model(prob_model_path)
# Scaler applied to the clinical scores (same pickle caveat as above)
with open(scaler_path, 'rb') as _fh:
    scaler = pickle.load(_fh)
# -------------------------------- Functions ---------------------------------
def load_img(file):
    """Read an uploaded MRI volume and prepare it for display.

    Args:
        file: the upload object passed by the ``gr.File`` component (its
            ``.name`` is the path handed to ``utils.load_img``), or ``None``
            when the component is cleared.

    Returns:
        ``(sitk, mri_image)``, where ``sitk`` is the first value returned by
        ``utils.load_img`` and ``mri_image`` is the transposed volume
        zero-padded to 192 slices along axis 0, cast to float32 and rotated
        90 degrees in the plane of axes 1 and 2. Returns ``(None, None)`` when
        ``file`` is ``None``.

    Raises:
        ValueError: if the volume has more than 192 slices along axis 0.
    """
    target_depth = 192
    if file is None:
        # Clearing the File component (e.g. via the Clear button) fires its
        # .change event with None; there is nothing to load.
        return None, None
    sitk, array = utils.load_img(file.name)
    # Resize: zero-pad at the end of axis 0 up to `target_depth` slices.
    mri_image = np.transpose(array)
    n_slices = mri_image.shape[0]
    if n_slices > target_depth:
        raise ValueError(
            f"MRI volume has {n_slices} slices; at most {target_depth} are supported"
        )
    # NOTE(review): the previous code padded with 256x256 slices, i.e. it
    # assumed that in-plane size; downstream models may still require it.
    mri_image = np.pad(mri_image, ((0, target_depth - n_slices), (0, 0), (0, 0)))
    # Rotation
    mri_image = mri_image.astype(np.float32)
    mri_image = np.rot90(mri_image, axes=(1, 2))
    return sitk, mri_image
def show_img(img, mri_slice, update):
    """Render one slice of a volume, taken along its first axis, in grayscale.

    Args:
        img: array indexed as ``img[slice, :, :]``.
        mri_slice: index of the slice to draw (a slider value).
        update: when True, also return two ``gr.update(visible=True)`` dicts
            so the caller can reveal related components.

    Returns:
        The figure, or ``(fig, gr.update(visible=True), gr.update(visible=True))``
        when ``update`` is True.
    """
    # Build the figure without pyplot. pyplot keeps every figure it creates
    # alive until plt.close(), so each slider move would leak one figure.
    # Its global state is also not thread-safe for concurrent requests.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.add_subplot()
    # Slider values may arrive as floats; array indices must be integers.
    ax.imshow(img[int(mri_slice), :, :], cmap='gray')
    if update:
        return fig, gr.update(visible=True), gr.update(visible=True)
    return fig
# def show_brain(brain, brain_slice): | |
# fig = plt.figure() | |
# plt.imshow(brain[brain_slice,:,:], cmap='gray') | |
# return fig, gr.update(visible=True) | |
def process_img(img, brain_slice):
    """Skull-strip the loaded MRI and draw one slice of the result.

    Args:
        img: the first value returned by ``load_img`` (held in a gr.State).
        brain_slice: slice index to display, from the processed-MRI slider.

    Returns:
        ``(brain, fig, slider_update, button_update)``: the output of
        ``utils.brain_stripping``, the figure for the processed-MRI plot, and
        two ``gr.update(visible=True)`` dicts that reveal the slider and the
        diagnosis button.
    """
    # Run the 3D U-Net under a CPU device scope.
    with tf.device("cpu:0"):
        stripped = utils.brain_stripping(img, model_unet)
    rendered, reveal_slider, _ignored = show_img(stripped, brain_slice, update=True)
    reveal_button = gr.update(visible=True)
    return stripped, rendered, reveal_slider, reveal_button
def get_diagnosis(brain_img, age, MMSE, GDSCALE, CDR, FAQ, NPI, sex):
    """Estimate the probability that the patient has MCI.

    Pipeline:
        1. Extract MedNet features from the skull-stripped volume.
        2. Classify them with the SVM.
        3. Append the resulting label to the clinical scores.
        4. Normalize the scores with the fitted scaler.
        5. Feed them to the risk model.

    Args:
        brain_img: skull-stripped volume produced by ``process_img``.
        age: patient age from a ``gr.Number`` (may be None if left empty).
        MMSE, GDSCALE, CDR, FAQ, NPI: clinical score values from the sliders.
        sex: "Male" or "Female" from the dropdown (None if nothing selected).

    Returns:
        A ``gr.update`` that sets the diagnosis textbox text.

    Raises:
        gr.Error: if the processed image, the age or the sex is missing.
    """
    import logging

    logger = logging.getLogger(__name__)

    if brain_img is None:
        raise gr.Error("Process the MRI before requesting a diagnosis.")
    if age is None:
        raise gr.Error("Please enter the patient's age.")
    if sex not in ("Male", "Female"):
        # Previously a missing selection silently counted as "Female".
        raise gr.Error("Please select the patient's sex.")

    # Image feature extraction
    features = utils.get_features(brain_img, mednet_model)
    # Image classification. predict() presumably returns one label per sample
    # (scikit-learn convention), so keep the single scalar. Nesting that array
    # inside the score list built a ragged array, which NumPy >= 1.24 rejects.
    label_img = np.ravel(svm_model.predict(features))[0]
    sex_dum = 1 if sex == "Male" else 0
    # NOTE(review): this order presumably matches the column order the scaler
    # and risk model were fitted with — keep them in sync.
    scores = np.array(
        [age, MMSE, GDSCALE, CDR, FAQ, NPI, sex_dum, label_img], dtype=float
    )
    # Patient data goes to debug logs only, not unconditionally to stdout.
    logger.debug("Raw scores: %s", scores)
    # Score normalization
    scores_norm = scaler.transform(scores.reshape(1, -1))
    logger.debug("Normalized scores: %s", scores_norm)
    with tf.device("cpu:0"):
        # Probability of having MCI
        prob = prob_model.predict(scores_norm)[0, 0]
    logger.debug("MCI probability: %s", prob)
    diagnosis = f"The patient has a probability of {(100*prob):.2f}% of having MCI"
    return gr.update(value=diagnosis)
def clear():
    """Reset the interface to its initial state.

    Returns one update per output of the Clear button, in this order:
        1. file input
        2. original-MRI plot
        3. MRI slider
        4. processed-MRI plot
        5. processed-MRI slider
        6. diagnosis text
        7. process button
        8. diagnosis button
    """
    reset_file = gr.File.update(value=None)
    reset_original_plot = gr.Plot.update(value=None)
    hide_mri_slider = gr.update(visible=False)
    reset_brain_plot = gr.Plot.update(value=None)
    hide_brain_slider = gr.update(visible=False)
    reset_diagnosis = gr.update(value="The diagnosis will show here...")
    hide_process_button = gr.update(visible=False)
    hide_diagnosis_button = gr.update(visible=False)
    return (
        reset_file,
        reset_original_plot,
        hide_mri_slider,
        reset_brain_plot,
        hide_brain_slider,
        reset_diagnosis,
        hide_process_button,
        hide_diagnosis_button,
    )
# --------------------------------- Interface ---------------------------------
# NOTE(review): this section's indentation was lost in SOURCE; the component
# nesting below is a reconstruction — confirm it against the intended layout.
with gr.Blocks(theme=gr.themes.Base()) as demo:
    # Header logo
    with gr.Row():
        gr.HTML(r"""<center><img src='https://user-images.githubusercontent.com/66338785/233531457-f368e04b-5099-42a8-906d-6f1250ea0f1e.png' width="40%" height="40%"></center>""")
    with gr.Row():
        # ------------------------------- Inputs -------------------------------
        with gr.Column(variant="panel", scale=1):
            gr.Markdown('<h2 style="text-align: center; color:#235784;">Patient Information</h2>')
            with gr.Tab("Personal data"):
                input_name = gr.Textbox(placeholder='Enter the patient name', label='Patient name')
                input_age = gr.Number(label='Age')
                # Phone numbers are identifiers, not quantities: a numeric field
                # drops leading zeros and rejects "+", spaces and dashes.
                input_phone_num = gr.Textbox(label='Phone number')
                input_emer_name = gr.Textbox(placeholder='Enter the emergency contact name', label='Emergency contact name')
                input_emer_phone_num = gr.Textbox(label='Emergency contact phone number')
                input_sex = gr.Dropdown(["Male", "Female"], label="Sex")
            with gr.Tab("Clinical data"):
                input_MMSE = gr.Slider(minimum=0, maximum=30, value=0, step=1,
                                       label="MMSE total score")
                input_GDSCALE = gr.Slider(minimum=0, maximum=12, value=0, step=1,
                                          label="GDSCALE total score")
                input_CDR = gr.Slider(minimum=0, maximum=3, value=0, step=0.5,
                                      label="Global CDR")
                input_FAQ = gr.Slider(minimum=0, maximum=30, value=0, step=1,
                                      label="FAQ total score")
                input_NPI_Q = gr.Slider(minimum=0, maximum=30, value=0, step=1,
                                        label="NPI-Q total score")
            with gr.Tab("Vital Signs"):
                input_Diastolic_blood_pressure = gr.Number(label='Diastolic Blood Pressure (mm Hg)')
                input_Systolic_blood_pressure = gr.Number(label='Systolic Blood Pressure (mm Hg)')
                input_Body_heigth = gr.Number(label='Body height (cm)')
                input_Body_weight = gr.Number(label='Body weight (kg)')
                input_Heart_rate = gr.Number(label='Heart rate (bpm)')
                input_Respiratory_rate = gr.Number(label='Respiratory rate (bpm)')
                input_Body_temperature = gr.Number(label='Body temperature (°C)')
                input_Pluse_oximetry = gr.Number(label='Pulse oximetry (%)')
            with gr.Tab("Medications"):
                input_medications = gr.Textbox(label='Medications', lines=5)
                input_allergies = gr.Textbox(label='Allergies', lines=5)
            # Upload field for the NIfTI MRI file
            input_file = gr.File(file_count="single", label="MRI Image File (.nii)")
            with gr.Row():
                # Button to display the loaded image
                load_img_button = gr.Button(value="Load")
                # Button to reset the interface
                clear_button = gr.Button(value="Clear")
            # Button to process the image (revealed after loading)
            process_button = gr.Button(value="Process MRI", visible=False, variant="primary")
            # Button to get the diagnosis (revealed after processing)
            diagnostic_button = gr.Button(value="Get diagnosis", visible=False, variant="primary")
        # ------------------------------- Outputs ------------------------------
        with gr.Column(variant="panel", scale=1):
            gr.Markdown('<h2 style="text-align: center; color:#235784;">MRI visualization</h2>')
            with gr.Box():
                gr.Markdown('<h4 style="color:#235784;">Loaded MRI</h4>')
                # Plot for the original image
                plot_img_original = gr.Plot(show_label=False)
                # Slice selector for the original image. load_img pads volumes
                # to 192 slices, so valid indices are 0..191.
                mri_slider = gr.Slider(minimum=0,
                                       maximum=191,
                                       value=100,
                                       step=1,
                                       label="MRI Slice",
                                       visible=False)
            with gr.Box():
                gr.Markdown('<h4 style="color:#235784;">Processed MRI</h4>')
                # Plot for the processed image
                plot_brain = gr.Plot(show_label=False, visible=True)
                # Slice selector for the processed image.
                # NOTE(review): assumes the stripped volume also has 192 slices
                # along axis 0 — confirm against utils.brain_stripping.
                brain_slider = gr.Slider(minimum=0,
                                         maximum=191,
                                         value=100,
                                         step=1,
                                         label="MRI Slice",
                                         visible=False)
            with gr.Box():
                gr.Markdown('<h2 style="text-align: center; color:#235784;">Diagnosis</h2>')
                # Diagnosis text
                diagnosis_text = gr.Textbox(label="Diagnosis", interactive=False,
                                            placeholder="The diagnosis will show here...")

    # Per-session state
    original_input_sitk = gr.State()
    original_input_img = gr.State()
    brain_img = gr.State()
    update_true = gr.State(True)
    update_false = gr.State(False)

    # ------------------------------- Events -------------------------------
    # Read a newly uploaded file into session state
    input_file.change(load_img,
                      input_file,
                      [original_input_sitk, original_input_img])
    # Show the newly loaded image and reveal its slider and the process button
    load_img_button.click(show_img,
                          [original_input_img, mri_slider, update_true],
                          [plot_img_original, mri_slider, process_button])
    # Redraw the original image when its slider moves
    mri_slider.change(show_img,
                      [original_input_img, mri_slider, update_false],
                      [plot_img_original])
    # Process the image
    process_button.click(fn=process_img,
                         inputs=[original_input_sitk, brain_slider],
                         outputs=[brain_img, plot_brain, brain_slider, diagnostic_button])
    # Redraw the processed image when its slider moves
    brain_slider.change(show_img,
                        [brain_img, brain_slider, update_false],
                        [plot_brain])
    # Compute the diagnosis
    diagnostic_button.click(fn=get_diagnosis,
                            inputs=[brain_img, input_age, input_MMSE, input_GDSCALE,
                                    input_CDR, input_FAQ, input_NPI_Q, input_sex],
                            outputs=[diagnosis_text])
    # Reset the fields
    clear_button.click(fn=clear,
                       outputs=[input_file, plot_img_original, mri_slider, plot_brain,
                                brain_slider, diagnosis_text, process_button,
                                diagnostic_button])

if __name__ == "__main__":
    # Request queueing is currently disabled:
    # demo.queue(concurrency_count=20)
    demo.launch()