import os

import numpy as np
import skimage.transform as trans
from skimage.color import rgb2gray

from unet.unet import unet
from unet.unet_3plus import UNet_3Plus, UNet_3Plus_DeepSup, UNet_3Plus_DeepSup_CGM


def _load_model(unet_type, input_shape, output_channels, model_path):
    """Build the U-Net variant selected by ``unet_type`` with its pretrained weights.

    Returns a ``(model, deep_supervision)`` pair. ``deep_supervision`` is True for
    the UNet 3+ variants whose ``predict`` result is indexed with ``[0]`` before
    use. Any ``unet_type`` other than 'v0', 'v1' or 'v2' selects the 'v3' model.
    """

    def weights(name):
        # Weight files live at "<model_path>/<name>.hdf5"; a relative
        # model_path is resolved against the current working directory.
        return os.path.join(model_path, name + ".hdf5")

    if unet_type == "v0":
        return unet(weights("unetv0_sgd500_neptune")), False
    if unet_type == "v1":
        model = UNet_3Plus(input_shape, output_channels, weights("unetv1_sgd500_neptune"))
        return model, False
    if unet_type == "v2":
        model = UNet_3Plus_DeepSup(
            input_shape, output_channels, weights("unetv2_sgd500_neptune")
        )
        return model, True
    # Fallback (including "v3"): UNet 3+ with deep supervision and CGM.
    model = UNet_3Plus_DeepSup_CGM(
        input_shape, output_channels, weights("unetv3_sgd500_neptune")
    )
    return model, True


def predict_model(input, unet_type):
    """Segment an image with a pretrained U-Net and return a colour overlay.

    Parameters
    ----------
    input : array-like
        Image to segment. RGB ``(H, W, 3)`` is the primary case; an RGBA image
        (alpha channel ignored) or a 2-D grayscale image is also accepted.
    unet_type : str
        'v0' = U-Net, 'v1' = UNet 3+, 'v2' = UNet 3+ with deep supervision,
        'v3' (or any other value) = UNet 3+ with deep supervision and CGM.

    Returns
    -------
    numpy.ndarray
        The image resized to 256x256 as a float RGB array ``(256, 256, 3)``,
        with every pixel whose predicted probability is >= 0.5 painted blue.
    """
    h, w = 256, 256
    input_shape = [h, w, 1]
    output_channels = 1
    threshold = 0.5
    # Full-intensity blue on the [0, 1] float scale that resize() produces
    # for integer images (255 would be far out of range here).
    seg_color = (0.0, 0.0, 1.0)

    # skimage's output_shape is (rows, cols), i.e. (h, w).
    img = trans.resize(np.asarray(input), (h, w))
    if img.ndim == 2:
        # Grayscale input: feed it directly, and build an RGB canvas for the overlay.
        gray = img
        result_img = np.dstack([img, img, img])
    else:
        # Keep only the first three channels so RGBA input is treated as RGB.
        result_img = img[..., :3].copy()
        gray = rgb2gray(result_img)

    model, deep_supervision = _load_model(unet_type, input_shape, output_channels, "weights")

    # The models take a single-channel batch of one: (1, h, w, 1).
    results = model.predict(gray.reshape(1, h, w, 1))
    pred = np.asarray(results[0] if deep_supervision else results)

    # Prediction is indexed as (batch, row, col, channel): keep the single
    # sample and channel, then threshold into a boolean foreground mask.
    mask = pred[0, :, :, 0] >= threshold
    result_img[mask] = seg_color
    return result_img