Spaces:
Paused
Paused
| # import os | |
| import torch | |
| import resnet | |
| import numpy as np | |
| import tensorflow as tf | |
| # import nibabel as nib | |
| import SimpleITK as sitk | |
| import segmentation_models_3D as sm | |
| from torch import nn | |
| # from ttictoc import tic,toc | |
| from skimage import morphology | |
| from keras import backend as K | |
| from scipy import ndimage as ndi | |
| from keras.models import load_model | |
| from patchify import patchify, unpatchify | |
| # from matplotlib import pyplot as plt | |
| # from matplotlib.widgets import Slider | |
def import_3d_unet(path_3d_unet):
    """Load the pretrained 3-D U-Net used for brain extraction.

    The saved Keras model refers to two custom metrics, so both are supplied
    through ``custom_objects`` when the model is deserialised.

    Args:
        path_3d_unet: path to the saved Keras model.

    Returns:
        The loaded Keras model.
    """

    def dice_coefficient(y_true, y_pred):
        # Smoothed Dice score: (2 * sum(t * p) + s) / (sum(t) + sum(p) + s), s = 1.
        smooth = 1
        truth = K.flatten(y_true)
        pred = K.flatten(y_pred)
        overlap = K.sum(truth * pred)
        total = K.sum(truth) + K.sum(pred)
        return (2. * overlap + smooth) / (total + smooth)

    custom_objects = {
        'dice_coefficient': dice_coefficient,
        'iou_score': sm.metrics.IOUScore(threshold=0.5),
    }
    return load_model(path_3d_unet, custom_objects=custom_objects)
def load_img(path):
    """Read a T1 MRI volume (e.g. a NIfTI file) with SimpleITK.

    Only reading happens here; N4 bias-field correction and intensity
    normalisation are performed later, inside ``brain_stripping``.

    Args:
        path: path to the image file.

    Returns:
        Tuple ``(image, array)``: the ``sitk.Image`` read as 32-bit float, and
        a float32 numpy array of its voxels as produced by
        ``sitk.GetArrayFromImage``.
    """
    image = sitk.ReadImage(path, sitk.sitkFloat32)
    voxels = sitk.GetArrayFromImage(image)
    return image, voxels.astype(np.float32)
def _n4_normalize(inputImage):
    """Return the N4 bias-corrected image as a float32 array min-max scaled to [0, 1]."""
    # Otsu mask (inside value 0, outside value 1, 200 histogram bins) tells N4
    # where to fit the bias field.
    mask_image = sitk.OtsuThreshold(inputImage, 0, 1, 200)
    corrector = sitk.N4BiasFieldCorrectionImageFilter()
    # Execute() fits the bias field; its direct output is not needed because
    # the field is re-applied to the input image below.
    corrector.Execute(inputImage, mask_image)
    log_bias_field = corrector.GetLogBiasFieldAsImage(inputImage)
    corrected = inputImage / sitk.Exp(log_bias_field)

    voxels = sitk.GetArrayFromImage(corrected)
    lo, hi = np.min(voxels), np.max(voxels)
    if hi > lo:
        voxels = (voxels - lo) / (hi - lo)
    else:
        # Constant image: avoid a 0/0 division (all-NaN volume).
        voxels = np.zeros_like(voxels)
    return voxels.astype(np.float32)


def _prepare_volume(voxels):
    """Transpose, zero-pad and rotate a normalised volume onto the (192, 256, 256) grid."""
    volume = np.transpose(voxels)
    depth, height, width = volume.shape
    if depth > 192 or (height, width) != (256, 256):
        raise ValueError(
            "brain_stripping expects a volume of at most 192 x 256 x 256 after "
            f"transposition, got {volume.shape}"
        )
    # Zero-pad the end of the first axis up to 192 slices.
    volume = np.pad(volume, ((0, 192 - depth), (0, 0), (0, 0)))
    volume = volume.astype(np.float32, copy=False)
    return np.rot90(volume, axes=(1, 2))


def _predict_mask(mri_image, model_unet):
    """Predict a binary uint8 brain mask for ``mri_image`` on non-overlapping 64^3 patches."""
    patch_size = 64
    patches = patchify(mri_image, (patch_size,) * 3, step=patch_size)
    predictions = []
    # One predict() call per patch (batch of 1), visiting the patch grid in
    # C order so the list lines up with ``patches`` when reshaped.
    for index in np.ndindex(*patches.shape[:3]):
        batch = np.expand_dims(patches[index], axis=0)
        scores = model_unet.predict(batch)
        predictions.append((scores[0, :, :, :, 0] > 0.5).astype(np.uint8))
    stacked = np.reshape(np.array(predictions), patches.shape)
    return unpatchify(stacked, mri_image.shape)


def _clean_mask(mask):
    """Smooth the raw mask, keep its largest connected component, then fill holes."""
    smoothed = ndi.binary_closing(mask, structure=morphology.ball(2)).astype(np.uint8)
    labeled = morphology.label(smoothed, background=0, connectivity=3)
    # Voxel count per label; index 0 is the background (count may be 0).
    sizes = np.bincount(labeled.ravel())
    if sizes.size < 2:
        raise ValueError("brain_stripping: the U-Net predicted no brain voxels")
    brain_label = np.argmax(sizes[1:]) + 1
    largest = (labeled == brain_label).astype(np.uint8)
    # A large closing element fills holes and clefts left in the brain mask.
    return ndi.binary_closing(largest, structure=morphology.ball(12)).astype(np.uint8)


def brain_stripping(inputImage, model_unet):
    """Skull-strip a T1 MRI volume with a patch-wise 3-D U-Net.

    Pipeline:
      1. N4 bias-field correction (Otsu mask) and min-max normalisation to
         [0, 1].
      2. Transpose, zero-pad the first axis to 192 slices and rotate 90
         degrees in the plane of axes 1-2, giving a (192, 256, 256) float32
         volume.
      3. Threshold the U-Net output at 0.5 on non-overlapping 64^3 patches
         and stitch the patches back together.
      4. Clean the mask: closing with a radius-2 ball, keep the largest
         connected component, closing with a radius-12 ball.
      5. Multiply the volume by the mask.

    Args:
        inputImage: ``sitk.Image`` (for example the first value returned by
            ``load_img``). Its voxel array, after ``np.transpose``, must have
            at most 192 entries along axis 0 and exactly 256 along axes 1
            and 2.
        model_unet: Keras model applied to a batch of one 64^3 patch; channel
            0 of its last axis is thresholded at 0.5 (presumably a brain
            probability -- confirm against the trained model).

    Returns:
        float32 numpy array of shape (192, 256, 256) with non-brain voxels
        set to 0.

    Raises:
        ValueError: if the volume does not fit the expected grid, or if the
            U-Net predicts no brain voxels at all.
    """
    mri_image = _prepare_volume(_n4_normalize(inputImage))
    raw_mask = _predict_mask(mri_image, model_unet)
    brain_mask = _clean_mask(raw_mask)
    return np.multiply(mri_image, brain_mask)
def create_mednet(weight_path, device_ids):
    """Build the MedNet ResNet-50 feature extractor and load pretrained weights.

    The ResNet-50 ``conv_seg`` head is replaced by a global 3-D max-pool
    followed by a flatten, so the network emits a flat feature vector per
    input volume. The model is wrapped in ``DataParallel``, placed on
    ``cuda:<first device id>``, frozen and switched to eval mode.

    Args:
        weight_path: path to a checkpoint dict whose ``'state_dict'`` entry
            holds the pretrained ResNet weights.
        device_ids: CUDA device ids handed to ``torch.nn.DataParallel``.

    Returns:
        The ``DataParallel``-wrapped model in eval mode.
    """

    class simci_net(nn.Module):
        # ResNet-50 backbone whose segmentation head is swapped for pooling.
        def __init__(self):
            super().__init__()
            self.pretrained_model = resnet.resnet50(
                sample_input_D=192,
                sample_input_H=256,
                sample_input_W=256,
                num_seg_classes=2,
                no_cuda=False,
            )
            self.pretrained_model.conv_seg = nn.Sequential(
                nn.AdaptiveMaxPool3d(output_size=(1, 1, 1)),
                nn.Flatten(start_dim=1),
            )

        def forward(self, x):
            return self.pretrained_model(x)

    model = torch.nn.DataParallel(simci_net(), device_ids=device_ids)
    device = f'cuda:{model.device_ids[0]}'
    model.to(device)

    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(weight_path, map_location=torch.device(device))

    # Checkpoint keys are expected to carry a "module." prefix (presumably
    # saved from a DataParallel model); in this wrapper the same layers live
    # under "module.pretrained_model.". Only keys present in the new network
    # are kept; the replacement conv_seg head has no parameters to load.
    state = model.state_dict()
    old_prefix, new_prefix = "module.", "module.pretrained_model."
    renamed = (
        (key.replace(old_prefix, new_prefix), value)
        for key, value in checkpoint['state_dict'].items()
    )
    state.update({key: value for key, value in renamed if key in state})
    model.load_state_dict(state)

    # Freeze the backbone: it is used purely as a feature extractor.
    for param in model.module.pretrained_model.parameters():
        param.requires_grad = False
    model.eval()
    return model
def get_features(brain, mednet_model):
    """Run a 3-D volume through a feature-extraction network and return its output.

    Args:
        brain: 3-D numpy array, e.g. the skull-stripped volume returned by
            ``brain_stripping``. Any memory layout is accepted, including
            negative-stride views such as those made by ``np.rot90``. The
            data is cast to the dtype of the model's first parameter
            (float32 for a parameter-free model).
        mednet_model: a ``torch.nn.Module``. For a ``DataParallel`` wrapper
            with non-empty ``device_ids`` (as built by ``create_mednet``) the
            input goes to ``cuda:<device_ids[0]>``. Otherwise it goes to the
            device of the model's first parameter, or the CPU if the model
            has no parameters.

    Returns:
        numpy array holding the model output for the input reshaped to
        (1, 1, *brain.shape), copied back to host memory.
    """
    device_ids = getattr(mednet_model, 'device_ids', None)
    first_param = next(iter(mednet_model.parameters()), None)
    if device_ids:
        device = torch.device(f'cuda:{device_ids[0]}')
    elif first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cpu')
    dtype = first_param.dtype if first_param is not None else torch.float32

    with torch.no_grad():
        # torch.from_numpy rejects negative strides, so force a C-contiguous
        # buffer first (no copy when the input already is one).
        volume = np.ascontiguousarray(brain)
        # Add batch and channel axes: (D, H, W) -> (1, 1, D, H, W).
        volume = volume.reshape((1, 1) + volume.shape)
        data = torch.from_numpy(volume).to(device=device, dtype=dtype)
        features = mednet_model(data).cpu().numpy()
        # Drop the last reference to the input tensor so empty_cache() can
        # actually release its GPU memory.
        del data
        torch.cuda.empty_cache()
    return features