# unet/unet_3plus.py
import tensorflow as tf
import tensorflow.keras as k
# Model Architecture
def conv_block(x, kernels, kernel_size=(3, 3), strides=(1, 1), padding='same',
is_bn=True, is_relu=True, n=2):
""" Custom function for conv2d:
Apply 3*3 convolutions with BN and relu.
"""
for i in range(1, n + 1):
x = k.layers.Conv2D(filters=kernels, kernel_size=kernel_size,
padding=padding, strides=strides,
kernel_regularizer=tf.keras.regularizers.l2(1e-4),
kernel_initializer=k.initializers.he_normal(seed=5))(x)
if is_bn:
x = k.layers.BatchNormalization()(x)
if is_relu:
x = k.activations.relu(x)
return x
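
# Example usage of conv_block (a minimal sketch, not part of the original file):
# two stacked 3x3 Conv2D + BN + ReLU layers with 64 filters, as used by the
# encoder blocks below. With 'same' padding the spatial size is unchanged:
#   inp = k.layers.Input(shape=(320, 320, 3))
#   y = conv_block(inp, kernels=64, n=2)   # (None, 320, 320, 64)
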
def dotProduct(seg, cls):
B, H, W, N = k.backend.int_shape(seg)
seg = tf.reshape(seg, [-1, H * W, N])
final = tf.einsum("ijk,ik->ijk", seg, cls)
final = tf.reshape(final, [-1, H, W, N])
return final
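
# Example of the gating performed by dotProduct (a sketch with assumed toy
# shapes): each channel c of `seg` is multiplied by the per-image scalar
# cls[:, c], so channels whose flag is 0 are zeroed out:
#   seg = tf.random.uniform([2, 8, 8, 4])                      # (B, H, W, N)
#   cls = tf.constant([[1., 0., 1., 0.], [0., 1., 1., 1.]])    # (B, N)
#   gated = dotProduct(seg, cls)                               # (2, 8, 8, 4)
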
""" UNet_3Plus """
def UNet_3Plus(INPUT_SHAPE, OUTPUT_CHANNELS, pretrained_weights=None):
filters = [64, 128, 256, 512, 1024]
input_layer = k.layers.Input(shape=INPUT_SHAPE, name="input_layer") # 320*320*3
""" Encoder"""
# block 1
e1 = conv_block(input_layer, filters[0]) # 320*320*64
# block 2
e2 = k.layers.MaxPool2D(pool_size=(2, 2))(e1) # 160*160*64
e2 = conv_block(e2, filters[1]) # 160*160*128
# block 3
e3 = k.layers.MaxPool2D(pool_size=(2, 2))(e2) # 80*80*128
e3 = conv_block(e3, filters[2]) # 80*80*256
# block 4
e4 = k.layers.MaxPool2D(pool_size=(2, 2))(e3) # 40*40*256
e4 = conv_block(e4, filters[3]) # 40*40*512
# block 5
# bottleneck layer
e5 = k.layers.MaxPool2D(pool_size=(2, 2))(e4) # 20*20*512
e5 = conv_block(e5, filters[4]) # 20*20*1024
""" Decoder """
cat_channels = filters[0]
cat_blocks = len(filters)
upsample_channels = cat_blocks * cat_channels
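    # Full-scale skip connections (UNet 3+): every decoder stage fuses five
    # 64-channel maps, one from each scale, so each fused block carries
    # 5 * 64 = 320 channels before the 3x3 fusion convolution.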
""" d4 """
e1_d4 = k.layers.MaxPool2D(pool_size=(8, 8))(e1) # 320*320*64 --> 40*40*64
e1_d4 = conv_block(e1_d4, cat_channels, n=1) # 320*320*64 --> 40*40*64
e2_d4 = k.layers.MaxPool2D(pool_size=(4, 4))(e2) # 160*160*128 --> 40*40*128
e2_d4 = conv_block(e2_d4, cat_channels, n=1) # 160*160*128 --> 40*40*64
e3_d4 = k.layers.MaxPool2D(pool_size=(2, 2))(e3) # 80*80*256 --> 40*40*256
e3_d4 = conv_block(e3_d4, cat_channels, n=1) # 80*80*256 --> 40*40*64
e4_d4 = conv_block(e4, cat_channels, n=1) # 40*40*512 --> 40*40*64
    e5_d4 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(e5)  # 20*20*1024 --> 40*40*1024
    e5_d4 = conv_block(e5_d4, cat_channels, n=1)  # 40*40*1024 --> 40*40*64
d4 = k.layers.concatenate([e1_d4, e2_d4, e3_d4, e4_d4, e5_d4])
d4 = conv_block(d4, upsample_channels, n=1) # 40*40*320 --> 40*40*320
""" d3 """
e1_d3 = k.layers.MaxPool2D(pool_size=(4, 4))(e1) # 320*320*64 --> 80*80*64
e1_d3 = conv_block(e1_d3, cat_channels, n=1) # 80*80*64 --> 80*80*64
    e2_d3 = k.layers.MaxPool2D(pool_size=(2, 2))(e2)  # 160*160*128 --> 80*80*128
    e2_d3 = conv_block(e2_d3, cat_channels, n=1)  # 80*80*128 --> 80*80*64
    e3_d3 = conv_block(e3, cat_channels, n=1)  # 80*80*256 --> 80*80*64
e4_d3 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d4) # 40*40*320 --> 80*80*320
e4_d3 = conv_block(e4_d3, cat_channels, n=1) # 80*80*320 --> 80*80*64
    e5_d3 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(e5)  # 20*20*1024 --> 80*80*1024
    e5_d3 = conv_block(e5_d3, cat_channels, n=1)  # 80*80*1024 --> 80*80*64
d3 = k.layers.concatenate([e1_d3, e2_d3, e3_d3, e4_d3, e5_d3])
d3 = conv_block(d3, upsample_channels, n=1) # 80*80*320 --> 80*80*320
""" d2 """
e1_d2 = k.layers.MaxPool2D(pool_size=(2, 2))(e1) # 320*320*64 --> 160*160*64
e1_d2 = conv_block(e1_d2, cat_channels, n=1) # 160*160*64 --> 160*160*64
    e2_d2 = conv_block(e2, cat_channels, n=1)  # 160*160*128 --> 160*160*64
d3_d2 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d3) # 80*80*320 --> 160*160*320
d3_d2 = conv_block(d3_d2, cat_channels, n=1) # 160*160*320 --> 160*160*64
d4_d2 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d4) # 40*40*320 --> 160*160*320
d4_d2 = conv_block(d4_d2, cat_channels, n=1) # 160*160*320 --> 160*160*64
    e5_d2 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(e5)  # 20*20*1024 --> 160*160*1024
    e5_d2 = conv_block(e5_d2, cat_channels, n=1)  # 160*160*1024 --> 160*160*64
d2 = k.layers.concatenate([e1_d2, e2_d2, d3_d2, d4_d2, e5_d2])
d2 = conv_block(d2, upsample_channels, n=1) # 160*160*320 --> 160*160*320
""" d1 """
e1_d1 = conv_block(e1, cat_channels, n=1) # 320*320*64 --> 320*320*64
d2_d1 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d2) # 160*160*320 --> 320*320*320
    d2_d1 = conv_block(d2_d1, cat_channels, n=1)  # 320*320*320 --> 320*320*64
d3_d1 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d3) # 80*80*320 --> 320*320*320
d3_d1 = conv_block(d3_d1, cat_channels, n=1) # 320*320*320 --> 320*320*64
d4_d1 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(d4) # 40*40*320 --> 320*320*320
d4_d1 = conv_block(d4_d1, cat_channels, n=1) # 320*320*320 --> 320*320*64
    e5_d1 = k.layers.UpSampling2D(size=(16, 16), interpolation='bilinear')(e5)  # 20*20*1024 --> 320*320*1024
    e5_d1 = conv_block(e5_d1, cat_channels, n=1)  # 320*320*1024 --> 320*320*64
    d1 = k.layers.concatenate([e1_d1, d2_d1, d3_d1, d4_d1, e5_d1])
d1 = conv_block(d1, upsample_channels, n=1) # 320*320*320 --> 320*320*320
# last layer does not have batchnorm and relu
d = conv_block(d1, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
if OUTPUT_CHANNELS == 1:
output = k.activations.sigmoid(d)
else:
output = k.activations.softmax(d)
model = tf.keras.Model(inputs=input_layer, outputs=output, name='UNet_3Plus')
    if pretrained_weights:
model.load_weights(pretrained_weights)
return model
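
# Example usage (a minimal sketch; the 320x320x3 input and binary output are
# assumptions taken from the shape comments above, not fixed by this file):
#   model = UNet_3Plus(INPUT_SHAPE=(320, 320, 3), OUTPUT_CHANNELS=1)
#   model.compile(optimizer=k.optimizers.Adam(1e-4), loss='binary_crossentropy')
#   model.summary()
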
""" UNet_3Plus with Deep Supervison"""
def UNet_3Plus_DeepSup(INPUT_SHAPE, OUTPUT_CHANNELS, pretrained_weights=None):
filters = [64, 128, 256, 512, 1024]
input_layer = k.layers.Input(shape=INPUT_SHAPE, name="input_layer") # 320*320*3
""" Encoder"""
# block 1
e1 = conv_block(input_layer, filters[0]) # 320*320*64
# block 2
e2 = k.layers.MaxPool2D(pool_size=(2, 2))(e1) # 160*160*64
e2 = conv_block(e2, filters[1]) # 160*160*128
# block 3
e3 = k.layers.MaxPool2D(pool_size=(2, 2))(e2) # 80*80*128
e3 = conv_block(e3, filters[2]) # 80*80*256
# block 4
e4 = k.layers.MaxPool2D(pool_size=(2, 2))(e3) # 40*40*256
e4 = conv_block(e4, filters[3]) # 40*40*512
# block 5
# bottleneck layer
e5 = k.layers.MaxPool2D(pool_size=(2, 2))(e4) # 20*20*512
e5 = conv_block(e5, filters[4]) # 20*20*1024
""" Decoder """
cat_channels = filters[0]
cat_blocks = len(filters)
upsample_channels = cat_blocks * cat_channels
""" d4 """
e1_d4 = k.layers.MaxPool2D(pool_size=(8, 8))(e1) # 320*320*64 --> 40*40*64
e1_d4 = conv_block(e1_d4, cat_channels, n=1) # 320*320*64 --> 40*40*64
e2_d4 = k.layers.MaxPool2D(pool_size=(4, 4))(e2) # 160*160*128 --> 40*40*128
e2_d4 = conv_block(e2_d4, cat_channels, n=1) # 160*160*128 --> 40*40*64
e3_d4 = k.layers.MaxPool2D(pool_size=(2, 2))(e3) # 80*80*256 --> 40*40*256
e3_d4 = conv_block(e3_d4, cat_channels, n=1) # 80*80*256 --> 40*40*64
e4_d4 = conv_block(e4, cat_channels, n=1) # 40*40*512 --> 40*40*64
    e5_d4 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(e5)  # 20*20*1024 --> 40*40*1024
    e5_d4 = conv_block(e5_d4, cat_channels, n=1)  # 40*40*1024 --> 40*40*64
d4 = k.layers.concatenate([e1_d4, e2_d4, e3_d4, e4_d4, e5_d4])
d4 = conv_block(d4, upsample_channels, n=1) # 40*40*320 --> 40*40*320
""" d3 """
e1_d3 = k.layers.MaxPool2D(pool_size=(4, 4))(e1) # 320*320*64 --> 80*80*64
e1_d3 = conv_block(e1_d3, cat_channels, n=1) # 80*80*64 --> 80*80*64
    e2_d3 = k.layers.MaxPool2D(pool_size=(2, 2))(e2)  # 160*160*128 --> 80*80*128
    e2_d3 = conv_block(e2_d3, cat_channels, n=1)  # 80*80*128 --> 80*80*64
    e3_d3 = conv_block(e3, cat_channels, n=1)  # 80*80*256 --> 80*80*64
e4_d3 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d4) # 40*40*320 --> 80*80*320
e4_d3 = conv_block(e4_d3, cat_channels, n=1) # 80*80*320 --> 80*80*64
    e5_d3 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(e5)  # 20*20*1024 --> 80*80*1024
    e5_d3 = conv_block(e5_d3, cat_channels, n=1)  # 80*80*1024 --> 80*80*64
d3 = k.layers.concatenate([e1_d3, e2_d3, e3_d3, e4_d3, e5_d3])
d3 = conv_block(d3, upsample_channels, n=1) # 80*80*320 --> 80*80*320
""" d2 """
e1_d2 = k.layers.MaxPool2D(pool_size=(2, 2))(e1) # 320*320*64 --> 160*160*64
e1_d2 = conv_block(e1_d2, cat_channels, n=1) # 160*160*64 --> 160*160*64
    e2_d2 = conv_block(e2, cat_channels, n=1)  # 160*160*128 --> 160*160*64
d3_d2 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d3) # 80*80*320 --> 160*160*320
d3_d2 = conv_block(d3_d2, cat_channels, n=1) # 160*160*320 --> 160*160*64
d4_d2 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d4) # 40*40*320 --> 160*160*320
d4_d2 = conv_block(d4_d2, cat_channels, n=1) # 160*160*320 --> 160*160*64
    e5_d2 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(e5)  # 20*20*1024 --> 160*160*1024
    e5_d2 = conv_block(e5_d2, cat_channels, n=1)  # 160*160*1024 --> 160*160*64
d2 = k.layers.concatenate([e1_d2, e2_d2, d3_d2, d4_d2, e5_d2])
d2 = conv_block(d2, upsample_channels, n=1) # 160*160*320 --> 160*160*320
""" d1 """
e1_d1 = conv_block(e1, cat_channels, n=1) # 320*320*64 --> 320*320*64
d2_d1 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d2) # 160*160*320 --> 320*320*320
    d2_d1 = conv_block(d2_d1, cat_channels, n=1)  # 320*320*320 --> 320*320*64
d3_d1 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d3) # 80*80*320 --> 320*320*320
d3_d1 = conv_block(d3_d1, cat_channels, n=1) # 320*320*320 --> 320*320*64
d4_d1 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(d4) # 40*40*320 --> 320*320*320
d4_d1 = conv_block(d4_d1, cat_channels, n=1) # 320*320*320 --> 320*320*64
    e5_d1 = k.layers.UpSampling2D(size=(16, 16), interpolation='bilinear')(e5)  # 20*20*1024 --> 320*320*1024
    e5_d1 = conv_block(e5_d1, cat_channels, n=1)  # 320*320*1024 --> 320*320*64
    d1 = k.layers.concatenate([e1_d1, d2_d1, d3_d1, d4_d1, e5_d1])
d1 = conv_block(d1, upsample_channels, n=1) # 320*320*320 --> 320*320*320
""" Deep Supervision Part"""
# last layer does not have batchnorm and relu
d1 = conv_block(d1, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
d2 = conv_block(d2, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
d3 = conv_block(d3, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
d4 = conv_block(d4, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
e5 = conv_block(e5, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
    # d1 is already at the input resolution; no upsampling needed
d2 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d2)
d3 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d3)
d4 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(d4)
e5 = k.layers.UpSampling2D(size=(16, 16), interpolation='bilinear')(e5)
if OUTPUT_CHANNELS == 1:
d1 = k.activations.sigmoid(d1)
d2 = k.activations.sigmoid(d2)
d3 = k.activations.sigmoid(d3)
d4 = k.activations.sigmoid(d4)
e5 = k.activations.sigmoid(e5)
else:
d1 = k.activations.softmax(d1)
d2 = k.activations.softmax(d2)
d3 = k.activations.softmax(d3)
d4 = k.activations.softmax(d4)
e5 = k.activations.softmax(e5)
model = tf.keras.Model(inputs=input_layer, outputs=[d1, d2, d3, d4, e5], name='UNet_3Plus_DeepSup')
    if pretrained_weights:
model.load_weights(pretrained_weights)
return model
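
# Example usage (a sketch; the equal loss_weights are an assumption, not part of
# this file). The model returns five outputs [d1, d2, d3, d4, e5], all upsampled
# to the input resolution, so one loss term can be attached to each head:
#   model = UNet_3Plus_DeepSup(INPUT_SHAPE=(320, 320, 3), OUTPUT_CHANNELS=1)
#   model.compile(optimizer=k.optimizers.Adam(1e-4),
#                 loss=['binary_crossentropy'] * 5,
#                 loss_weights=[1.0, 1.0, 1.0, 1.0, 1.0])
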
""" UNet_3Plus with Deep Supervison and Classification Guided Module"""
def UNet_3Plus_DeepSup_CGM(INPUT_SHAPE, OUTPUT_CHANNELS, pretrained_weights=None):
filters = [64, 128, 256, 512, 1024]
input_layer = k.layers.Input(shape=INPUT_SHAPE, name="input_layer") # 320*320*3
""" Encoder"""
# block 1
e1 = conv_block(input_layer, filters[0]) # 320*320*64
# block 2
e2 = k.layers.MaxPool2D(pool_size=(2, 2))(e1) # 160*160*64
e2 = conv_block(e2, filters[1]) # 160*160*128
# block 3
e3 = k.layers.MaxPool2D(pool_size=(2, 2))(e2) # 80*80*128
e3 = conv_block(e3, filters[2]) # 80*80*256
# block 4
e4 = k.layers.MaxPool2D(pool_size=(2, 2))(e3) # 40*40*256
e4 = conv_block(e4, filters[3]) # 40*40*512
# block 5, bottleneck layer
e5 = k.layers.MaxPool2D(pool_size=(2, 2))(e4) # 20*20*512
e5 = conv_block(e5, filters[4]) # 20*20*1024
""" Classification Guided Module. Part 1"""
cls = k.layers.Dropout(rate=0.5)(e5)
cls = k.layers.Conv2D(2, kernel_size=(1, 1), padding="same", strides=(1, 1))(cls)
cls = k.layers.GlobalMaxPooling2D()(cls)
cls = k.activations.sigmoid(cls)
cls = tf.argmax(cls, axis=-1)
cls = cls[..., tf.newaxis]
    cls = tf.cast(cls, dtype=tf.float32)
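    # cls is now a per-image 0/1 flag taken from the 2-way classifier head;
    # in Part 2 below it gates every deep-supervision output via dotProduct.
    # Note that tf.argmax is non-differentiable, so no gradient flows back
    # through this classification branch.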
""" Decoder """
cat_channels = filters[0]
cat_blocks = len(filters)
upsample_channels = cat_blocks * cat_channels
""" d4 """
e1_d4 = k.layers.MaxPool2D(pool_size=(8, 8))(e1) # 320*320*64 --> 40*40*64
e1_d4 = conv_block(e1_d4, cat_channels, n=1) # 320*320*64 --> 40*40*64
e2_d4 = k.layers.MaxPool2D(pool_size=(4, 4))(e2) # 160*160*128 --> 40*40*128
e2_d4 = conv_block(e2_d4, cat_channels, n=1) # 160*160*128 --> 40*40*64
e3_d4 = k.layers.MaxPool2D(pool_size=(2, 2))(e3) # 80*80*256 --> 40*40*256
e3_d4 = conv_block(e3_d4, cat_channels, n=1) # 80*80*256 --> 40*40*64
e4_d4 = conv_block(e4, cat_channels, n=1) # 40*40*512 --> 40*40*64
    e5_d4 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(e5)  # 20*20*1024 --> 40*40*1024
    e5_d4 = conv_block(e5_d4, cat_channels, n=1)  # 40*40*1024 --> 40*40*64
d4 = k.layers.concatenate([e1_d4, e2_d4, e3_d4, e4_d4, e5_d4])
d4 = conv_block(d4, upsample_channels, n=1) # 40*40*320 --> 40*40*320
""" d3 """
e1_d3 = k.layers.MaxPool2D(pool_size=(4, 4))(e1) # 320*320*64 --> 80*80*64
e1_d3 = conv_block(e1_d3, cat_channels, n=1) # 80*80*64 --> 80*80*64
    e2_d3 = k.layers.MaxPool2D(pool_size=(2, 2))(e2)  # 160*160*128 --> 80*80*128
    e2_d3 = conv_block(e2_d3, cat_channels, n=1)  # 80*80*128 --> 80*80*64
    e3_d3 = conv_block(e3, cat_channels, n=1)  # 80*80*256 --> 80*80*64
e4_d3 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d4) # 40*40*320 --> 80*80*320
e4_d3 = conv_block(e4_d3, cat_channels, n=1) # 80*80*320 --> 80*80*64
    e5_d3 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(e5)  # 20*20*1024 --> 80*80*1024
    e5_d3 = conv_block(e5_d3, cat_channels, n=1)  # 80*80*1024 --> 80*80*64
d3 = k.layers.concatenate([e1_d3, e2_d3, e3_d3, e4_d3, e5_d3])
d3 = conv_block(d3, upsample_channels, n=1) # 80*80*320 --> 80*80*320
""" d2 """
e1_d2 = k.layers.MaxPool2D(pool_size=(2, 2))(e1) # 320*320*64 --> 160*160*64
e1_d2 = conv_block(e1_d2, cat_channels, n=1) # 160*160*64 --> 160*160*64
    e2_d2 = conv_block(e2, cat_channels, n=1)  # 160*160*128 --> 160*160*64
d3_d2 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d3) # 80*80*320 --> 160*160*320
d3_d2 = conv_block(d3_d2, cat_channels, n=1) # 160*160*320 --> 160*160*64
d4_d2 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d4) # 40*40*320 --> 160*160*320
d4_d2 = conv_block(d4_d2, cat_channels, n=1) # 160*160*320 --> 160*160*64
    e5_d2 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(e5)  # 20*20*1024 --> 160*160*1024
    e5_d2 = conv_block(e5_d2, cat_channels, n=1)  # 160*160*1024 --> 160*160*64
d2 = k.layers.concatenate([e1_d2, e2_d2, d3_d2, d4_d2, e5_d2])
d2 = conv_block(d2, upsample_channels, n=1) # 160*160*320 --> 160*160*320
""" d1 """
e1_d1 = conv_block(e1, cat_channels, n=1) # 320*320*64 --> 320*320*64
d2_d1 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d2) # 160*160*320 --> 320*320*320
    d2_d1 = conv_block(d2_d1, cat_channels, n=1)  # 320*320*320 --> 320*320*64
d3_d1 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d3) # 80*80*320 --> 320*320*320
d3_d1 = conv_block(d3_d1, cat_channels, n=1) # 320*320*320 --> 320*320*64
d4_d1 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(d4) # 40*40*320 --> 320*320*320
d4_d1 = conv_block(d4_d1, cat_channels, n=1) # 320*320*320 --> 320*320*64
    e5_d1 = k.layers.UpSampling2D(size=(16, 16), interpolation='bilinear')(e5)  # 20*20*1024 --> 320*320*1024
    e5_d1 = conv_block(e5_d1, cat_channels, n=1)  # 320*320*1024 --> 320*320*64
    d1 = k.layers.concatenate([e1_d1, d2_d1, d3_d1, d4_d1, e5_d1])
d1 = conv_block(d1, upsample_channels, n=1) # 320*320*320 --> 320*320*320
""" Deep Supervision Part"""
# last layer does not have batchnorm and relu
d1 = conv_block(d1, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
d2 = conv_block(d2, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
d3 = conv_block(d3, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
d4 = conv_block(d4, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
e5 = conv_block(e5, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
    # d1 is already at the input resolution; no upsampling needed
d2 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d2)
d3 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d3)
d4 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(d4)
e5 = k.layers.UpSampling2D(size=(16, 16), interpolation='bilinear')(e5)
""" Classification Guided Module. Part 2"""
d1 = dotProduct(d1, cls)
d2 = dotProduct(d2, cls)
d3 = dotProduct(d3, cls)
d4 = dotProduct(d4, cls)
e5 = dotProduct(e5, cls)
if OUTPUT_CHANNELS == 1:
d1 = k.activations.sigmoid(d1)
d2 = k.activations.sigmoid(d2)
d3 = k.activations.sigmoid(d3)
d4 = k.activations.sigmoid(d4)
e5 = k.activations.sigmoid(e5)
else:
d1 = k.activations.softmax(d1)
d2 = k.activations.softmax(d2)
d3 = k.activations.softmax(d3)
d4 = k.activations.softmax(d4)
e5 = k.activations.softmax(e5)
model = tf.keras.Model(inputs=input_layer, outputs=[d1, d2, d3, d4, e5], name='UNet_3Plus_DeepSup_CGM')
    if pretrained_weights:
model.load_weights(pretrained_weights)
return model
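
if __name__ == "__main__":
    # Quick shape sanity check (a sketch; the 320x320x3 input and single output
    # channel are assumptions based on the shape comments above). Builds the
    # CGM variant and runs one dummy batch through it.
    model = UNet_3Plus_DeepSup_CGM(INPUT_SHAPE=(320, 320, 3), OUTPUT_CHANNELS=1)
    dummy = tf.zeros([1, 320, 320, 3])
    outputs = model(dummy, training=False)
    for i, out in enumerate(outputs, start=1):
        print(f"output {i}: {out.shape}")  # each should be (1, 320, 320, 1)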