Spaces:
Sleeping
Sleeping
import tensorflow as tf | |
import tensorflow.keras as k | |
# Model Architecture | |
def conv_block(x, kernels, kernel_size=(3, 3), strides=(1, 1), padding='same',
               is_bn=True, is_relu=True, n=2):
    """Stack ``n`` convolution units on top of ``x``.

    Each unit is a ``Conv2D`` (L2 kernel regularisation 1e-4, He-normal
    initialisation), followed by an optional ``BatchNormalization`` and then
    an optional ReLU, always in that order.

    Args:
        x: Input tensor.
        kernels: Number of filters used by every ``Conv2D`` in the stack.
        kernel_size: Convolution window (3x3 by default).
        strides: Convolution strides.
        padding: Padding mode forwarded to ``Conv2D``.
        is_bn: Add batch normalisation after each convolution when true.
        is_relu: Apply ReLU at the end of each unit when true.
        n: How many units to stack.

    Returns:
        Output tensor of the last unit (``x`` unchanged if ``n`` < 1).
    """
    out = x
    for _ in range(n):
        # NOTE(review): every convolution gets a fresh initializer with the
        # same seed=5, so kernels of identical shape probably start from
        # identical values. Confirm this reproducibility choice is intended.
        out = k.layers.Conv2D(
            filters=kernels,
            kernel_size=kernel_size,
            padding=padding,
            strides=strides,
            kernel_regularizer=tf.keras.regularizers.l2(1e-4),
            kernel_initializer=k.initializers.he_normal(seed=5),
        )(out)
        out = k.layers.BatchNormalization()(out) if is_bn else out
        out = k.activations.relu(out) if is_relu else out
    return out
def dotProduct(seg, cls):
    """Scale each sample's segmentation map by its classification factor(s).

    Computes ``out[b, h, w, c] = seg[b, h, w, c] * cls[b, c']``. Here ``c'``
    is ``c`` when ``cls`` has one column per channel, and ``0`` when ``cls``
    has a single column that applies to every channel.

    Args:
        seg: 4-D tensor ``(batch, H, W, N)``. ``H``, ``W`` and ``N`` must be
            statically known, because they are read with ``int_shape`` and
            used in ``tf.reshape``.
        cls: 2-D tensor of shape ``(batch, N)`` (one factor per channel) or
            ``(batch, 1)`` (one factor shared by all channels).

    Returns:
        Tensor with the same shape as ``seg``.
    """
    _, H, W, N = k.backend.int_shape(seg)
    if N != 1 and k.backend.int_shape(cls)[-1] == 1:
        # One factor per sample, but several channels. The einsum below pairs
        # cls's last axis with the N channels. tf.einsum does not broadcast
        # named labels, so a size-1 axis would be rejected. Broadcast the
        # (batch, 1) factor over H, W and N instead.
        return seg * cls[:, tf.newaxis, tf.newaxis, :]
    # One factor per channel; this also covers N == 1 with a (batch, 1) cls.
    # Keeping the original reshape -> einsum -> reshape op sequence leaves
    # the graph (and hence layer order) of already-working models unchanged.
    seg = tf.reshape(seg, [-1, H * W, N])
    final = tf.einsum("ijk,ik->ijk", seg, cls)
    return tf.reshape(final, [-1, H, W, N])
""" UNet_3Plus """ | |
def UNet_3Plus(INPUT_SHAPE, OUTPUT_CHANNELS, pretrained_weights = None):
    """Build the UNet 3+ segmentation network with a single output.

    Each decoder stage gathers features from all five scales:
    shallower encoder maps are max-pooled down, and deeper decoder or
    bottleneck maps are bilinearly up-sampled. Each gathered map is reduced
    to ``filters[0]`` channels with a one-layer conv block. The five maps are
    then concatenated and fused by another conv block.

    Args:
        INPUT_SHAPE: Per-sample input shape given to ``k.layers.Input``. The
            shape comments below assume ``(320, 320, 3)``. Height and width
            must be divisible by 16: the encoder halves them four times and
            the decoder up-samples by up to 16x. Other sizes make the
            concatenated maps disagree in spatial size.
        OUTPUT_CHANNELS: Channels of the final map. ``1`` selects a sigmoid
            output; any other value selects a softmax over the last axis
            (the Keras default).
        pretrained_weights: Optional path. When truthy, it is passed to
            ``model.load_weights`` after the model is built.

    Returns:
        A ``tf.keras.Model`` named ``'UNet_3Plus'``. Its single output has
        the input's height and width and ``OUTPUT_CHANNELS`` channels.
    """
    # NOTE(review): keep the statement order. Keras auto-names layers in
    # creation order, and order-based weight loading depends on the layer
    # sequence. Reordering may break loading of existing pretrained_weights.
    filters = [64, 128, 256, 512, 1024]
    input_layer = k.layers.Input(shape=INPUT_SHAPE, name="input_layer")  # 320*320*3
    """ Encoder"""
    # block 1
    e1 = conv_block(input_layer, filters[0])  # 320*320*64
    # block 2
    e2 = k.layers.MaxPool2D(pool_size=(2, 2))(e1)  # 160*160*64
    e2 = conv_block(e2, filters[1])  # 160*160*128
    # block 3
    e3 = k.layers.MaxPool2D(pool_size=(2, 2))(e2)  # 80*80*128
    e3 = conv_block(e3, filters[2])  # 80*80*256
    # block 4
    e4 = k.layers.MaxPool2D(pool_size=(2, 2))(e3)  # 40*40*256
    e4 = conv_block(e4, filters[3])  # 40*40*512
    # block 5
    # bottleneck layer
    e5 = k.layers.MaxPool2D(pool_size=(2, 2))(e4)  # 20*20*512
    e5 = conv_block(e5, filters[4])  # 20*20*1024
    """ Decoder """
    # Every branch feeding a decoder stage is reduced to cat_channels (64)
    # channels. Each fused stage therefore has 5 * 64 = 320 channels.
    cat_channels = filters[0]
    cat_blocks = len(filters)
    upsample_channels = cat_blocks * cat_channels
    """ d4 """
    e1_d4 = k.layers.MaxPool2D(pool_size=(8, 8))(e1)  # 320*320*64 --> 40*40*64
    e1_d4 = conv_block(e1_d4, cat_channels, n=1)  # 40*40*64 --> 40*40*64
    e2_d4 = k.layers.MaxPool2D(pool_size=(4, 4))(e2)  # 160*160*128 --> 40*40*128
    e2_d4 = conv_block(e2_d4, cat_channels, n=1)  # 40*40*128 --> 40*40*64
    e3_d4 = k.layers.MaxPool2D(pool_size=(2, 2))(e3)  # 80*80*256 --> 40*40*256
    e3_d4 = conv_block(e3_d4, cat_channels, n=1)  # 40*40*256 --> 40*40*64
    e4_d4 = conv_block(e4, cat_channels, n=1)  # 40*40*512 --> 40*40*64
    e5_d4 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(e5)  # 20*20*1024 --> 40*40*1024
    e5_d4 = conv_block(e5_d4, cat_channels, n=1)  # 40*40*1024 --> 40*40*64
    d4 = k.layers.concatenate([e1_d4, e2_d4, e3_d4, e4_d4, e5_d4])  # 40*40*320
    d4 = conv_block(d4, upsample_channels, n=1)  # 40*40*320 --> 40*40*320
    """ d3 """
    e1_d3 = k.layers.MaxPool2D(pool_size=(4, 4))(e1)  # 320*320*64 --> 80*80*64
    e1_d3 = conv_block(e1_d3, cat_channels, n=1)  # 80*80*64 --> 80*80*64
    e2_d3 = k.layers.MaxPool2D(pool_size=(2, 2))(e2)  # 160*160*128 --> 80*80*128
    e2_d3 = conv_block(e2_d3, cat_channels, n=1)  # 80*80*128 --> 80*80*64
    e3_d3 = conv_block(e3, cat_channels, n=1)  # 80*80*256 --> 80*80*64
    # Despite its name, this branch up-samples the decoder output d4, not e4.
    e4_d3 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d4)  # 40*40*320 --> 80*80*320
    e4_d3 = conv_block(e4_d3, cat_channels, n=1)  # 80*80*320 --> 80*80*64
    e5_d3 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(e5)  # 20*20*1024 --> 80*80*1024
    e5_d3 = conv_block(e5_d3, cat_channels, n=1)  # 80*80*1024 --> 80*80*64
    d3 = k.layers.concatenate([e1_d3, e2_d3, e3_d3, e4_d3, e5_d3])  # 80*80*320
    d3 = conv_block(d3, upsample_channels, n=1)  # 80*80*320 --> 80*80*320
    """ d2 """
    e1_d2 = k.layers.MaxPool2D(pool_size=(2, 2))(e1)  # 320*320*64 --> 160*160*64
    e1_d2 = conv_block(e1_d2, cat_channels, n=1)  # 160*160*64 --> 160*160*64
    e2_d2 = conv_block(e2, cat_channels, n=1)  # 160*160*128 --> 160*160*64
    d3_d2 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d3)  # 80*80*320 --> 160*160*320
    d3_d2 = conv_block(d3_d2, cat_channels, n=1)  # 160*160*320 --> 160*160*64
    d4_d2 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d4)  # 40*40*320 --> 160*160*320
    d4_d2 = conv_block(d4_d2, cat_channels, n=1)  # 160*160*320 --> 160*160*64
    e5_d2 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(e5)  # 20*20*1024 --> 160*160*1024
    e5_d2 = conv_block(e5_d2, cat_channels, n=1)  # 160*160*1024 --> 160*160*64
    d2 = k.layers.concatenate([e1_d2, e2_d2, d3_d2, d4_d2, e5_d2])  # 160*160*320
    d2 = conv_block(d2, upsample_channels, n=1)  # 160*160*320 --> 160*160*320
    """ d1 """
    e1_d1 = conv_block(e1, cat_channels, n=1)  # 320*320*64 --> 320*320*64
    d2_d1 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d2)  # 160*160*320 --> 320*320*320
    d2_d1 = conv_block(d2_d1, cat_channels, n=1)  # 320*320*320 --> 320*320*64
    d3_d1 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d3)  # 80*80*320 --> 320*320*320
    d3_d1 = conv_block(d3_d1, cat_channels, n=1)  # 320*320*320 --> 320*320*64
    d4_d1 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(d4)  # 40*40*320 --> 320*320*320
    d4_d1 = conv_block(d4_d1, cat_channels, n=1)  # 320*320*320 --> 320*320*64
    e5_d1 = k.layers.UpSampling2D(size=(16, 16), interpolation='bilinear')(e5)  # 20*20*1024 --> 320*320*1024
    e5_d1 = conv_block(e5_d1, cat_channels, n=1)  # 320*320*1024 --> 320*320*64
    d1 = k.layers.concatenate([e1_d1, d2_d1, d3_d1, d4_d1, e5_d1, ])  # 320*320*320
    d1 = conv_block(d1, upsample_channels, n=1)  # 320*320*320 --> 320*320*320
    # last layer does not have batchnorm and relu
    d = conv_block(d1, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)  # 320*320*OUTPUT_CHANNELS, pre-activation
    if OUTPUT_CHANNELS == 1:
        output = k.activations.sigmoid(d)
    else:
        output = k.activations.softmax(d)
    model = tf.keras.Model(inputs=input_layer, outputs=output, name='UNet_3Plus')
    if(pretrained_weights):
        model.load_weights(pretrained_weights)
    return model
""" UNet_3Plus with Deep Supervison""" | |
def UNet_3Plus_DeepSup(INPUT_SHAPE, OUTPUT_CHANNELS, pretrained_weights = None):
    """Build UNet 3+ with deep supervision (one prediction per decoder scale).

    The encoder and full-scale decoder are the same as in ``UNet_3Plus``.
    Instead of a single head on ``d1``, every scale (``d1``..``d4`` and the
    bottleneck ``e5``) gets its own plain 3x3 convolution to
    ``OUTPUT_CHANNELS``. Each head is bilinearly up-sampled back to the
    input resolution and passed through a sigmoid or softmax.

    Args:
        INPUT_SHAPE: Per-sample input shape given to ``k.layers.Input``. The
            shape comments below assume ``(320, 320, 3)``. Height and width
            must be divisible by 16 so the pooled and up-sampled maps line
            up when concatenated.
        OUTPUT_CHANNELS: Channels of every output map. ``1`` selects
            sigmoid outputs; any other value selects softmax over the last
            axis (the Keras default).
        pretrained_weights: Optional path. When truthy, it is passed to
            ``model.load_weights`` after the model is built.

    Returns:
        A ``tf.keras.Model`` named ``'UNet_3Plus_DeepSup'`` with five
        outputs ``[d1, d2, d3, d4, e5]``. Each has the input's height and
        width and ``OUTPUT_CHANNELS`` channels. ``d1`` comes from the finest
        decoder stage, ``e5`` from the bottleneck.
    """
    # NOTE(review): keep the statement order. Keras auto-names layers in
    # creation order, and order-based weight loading depends on the layer
    # sequence. Reordering may break loading of existing pretrained_weights.
    filters = [64, 128, 256, 512, 1024]
    input_layer = k.layers.Input(shape=INPUT_SHAPE, name="input_layer")  # 320*320*3
    """ Encoder"""
    # block 1
    e1 = conv_block(input_layer, filters[0])  # 320*320*64
    # block 2
    e2 = k.layers.MaxPool2D(pool_size=(2, 2))(e1)  # 160*160*64
    e2 = conv_block(e2, filters[1])  # 160*160*128
    # block 3
    e3 = k.layers.MaxPool2D(pool_size=(2, 2))(e2)  # 80*80*128
    e3 = conv_block(e3, filters[2])  # 80*80*256
    # block 4
    e4 = k.layers.MaxPool2D(pool_size=(2, 2))(e3)  # 40*40*256
    e4 = conv_block(e4, filters[3])  # 40*40*512
    # block 5
    # bottleneck layer
    e5 = k.layers.MaxPool2D(pool_size=(2, 2))(e4)  # 20*20*512
    e5 = conv_block(e5, filters[4])  # 20*20*1024
    """ Decoder """
    # Every branch feeding a decoder stage is reduced to cat_channels (64)
    # channels. Each fused stage therefore has 5 * 64 = 320 channels.
    cat_channels = filters[0]
    cat_blocks = len(filters)
    upsample_channels = cat_blocks * cat_channels
    """ d4 """
    e1_d4 = k.layers.MaxPool2D(pool_size=(8, 8))(e1)  # 320*320*64 --> 40*40*64
    e1_d4 = conv_block(e1_d4, cat_channels, n=1)  # 40*40*64 --> 40*40*64
    e2_d4 = k.layers.MaxPool2D(pool_size=(4, 4))(e2)  # 160*160*128 --> 40*40*128
    e2_d4 = conv_block(e2_d4, cat_channels, n=1)  # 40*40*128 --> 40*40*64
    e3_d4 = k.layers.MaxPool2D(pool_size=(2, 2))(e3)  # 80*80*256 --> 40*40*256
    e3_d4 = conv_block(e3_d4, cat_channels, n=1)  # 40*40*256 --> 40*40*64
    e4_d4 = conv_block(e4, cat_channels, n=1)  # 40*40*512 --> 40*40*64
    e5_d4 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(e5)  # 20*20*1024 --> 40*40*1024
    e5_d4 = conv_block(e5_d4, cat_channels, n=1)  # 40*40*1024 --> 40*40*64
    d4 = k.layers.concatenate([e1_d4, e2_d4, e3_d4, e4_d4, e5_d4])  # 40*40*320
    d4 = conv_block(d4, upsample_channels, n=1)  # 40*40*320 --> 40*40*320
    """ d3 """
    e1_d3 = k.layers.MaxPool2D(pool_size=(4, 4))(e1)  # 320*320*64 --> 80*80*64
    e1_d3 = conv_block(e1_d3, cat_channels, n=1)  # 80*80*64 --> 80*80*64
    e2_d3 = k.layers.MaxPool2D(pool_size=(2, 2))(e2)  # 160*160*128 --> 80*80*128
    e2_d3 = conv_block(e2_d3, cat_channels, n=1)  # 80*80*128 --> 80*80*64
    e3_d3 = conv_block(e3, cat_channels, n=1)  # 80*80*256 --> 80*80*64
    # Despite its name, this branch up-samples the decoder output d4, not e4.
    e4_d3 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d4)  # 40*40*320 --> 80*80*320
    e4_d3 = conv_block(e4_d3, cat_channels, n=1)  # 80*80*320 --> 80*80*64
    e5_d3 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(e5)  # 20*20*1024 --> 80*80*1024
    e5_d3 = conv_block(e5_d3, cat_channels, n=1)  # 80*80*1024 --> 80*80*64
    d3 = k.layers.concatenate([e1_d3, e2_d3, e3_d3, e4_d3, e5_d3])  # 80*80*320
    d3 = conv_block(d3, upsample_channels, n=1)  # 80*80*320 --> 80*80*320
    """ d2 """
    e1_d2 = k.layers.MaxPool2D(pool_size=(2, 2))(e1)  # 320*320*64 --> 160*160*64
    e1_d2 = conv_block(e1_d2, cat_channels, n=1)  # 160*160*64 --> 160*160*64
    e2_d2 = conv_block(e2, cat_channels, n=1)  # 160*160*128 --> 160*160*64
    d3_d2 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d3)  # 80*80*320 --> 160*160*320
    d3_d2 = conv_block(d3_d2, cat_channels, n=1)  # 160*160*320 --> 160*160*64
    d4_d2 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d4)  # 40*40*320 --> 160*160*320
    d4_d2 = conv_block(d4_d2, cat_channels, n=1)  # 160*160*320 --> 160*160*64
    e5_d2 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(e5)  # 20*20*1024 --> 160*160*1024
    e5_d2 = conv_block(e5_d2, cat_channels, n=1)  # 160*160*1024 --> 160*160*64
    d2 = k.layers.concatenate([e1_d2, e2_d2, d3_d2, d4_d2, e5_d2])  # 160*160*320
    d2 = conv_block(d2, upsample_channels, n=1)  # 160*160*320 --> 160*160*320
    """ d1 """
    e1_d1 = conv_block(e1, cat_channels, n=1)  # 320*320*64 --> 320*320*64
    d2_d1 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d2)  # 160*160*320 --> 320*320*320
    d2_d1 = conv_block(d2_d1, cat_channels, n=1)  # 320*320*320 --> 320*320*64
    d3_d1 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d3)  # 80*80*320 --> 320*320*320
    d3_d1 = conv_block(d3_d1, cat_channels, n=1)  # 320*320*320 --> 320*320*64
    d4_d1 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(d4)  # 40*40*320 --> 320*320*320
    d4_d1 = conv_block(d4_d1, cat_channels, n=1)  # 320*320*320 --> 320*320*64
    e5_d1 = k.layers.UpSampling2D(size=(16, 16), interpolation='bilinear')(e5)  # 20*20*1024 --> 320*320*1024
    e5_d1 = conv_block(e5_d1, cat_channels, n=1)  # 320*320*1024 --> 320*320*64
    d1 = k.layers.concatenate([e1_d1, d2_d1, d3_d1, d4_d1, e5_d1, ])  # 320*320*320
    d1 = conv_block(d1, upsample_channels, n=1)  # 320*320*320 --> 320*320*320
    """ Deep Supervision Part"""
    # One head per scale: a plain 3x3 conv to OUTPUT_CHANNELS (pre-activation
    # values), then bilinear up-sampling back to the input resolution.
    # last layer does not have batchnorm and relu
    d1 = conv_block(d1, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)  # 320*320*OUTPUT_CHANNELS
    d2 = conv_block(d2, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)  # 160*160*OUTPUT_CHANNELS
    d3 = conv_block(d3, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)  # 80*80*OUTPUT_CHANNELS
    d4 = conv_block(d4, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)  # 40*40*OUTPUT_CHANNELS
    e5 = conv_block(e5, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)  # 20*20*OUTPUT_CHANNELS
    # d1 = no need for upsampling
    d2 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d2)  # --> 320*320*OUTPUT_CHANNELS
    d3 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d3)  # --> 320*320*OUTPUT_CHANNELS
    d4 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(d4)  # --> 320*320*OUTPUT_CHANNELS
    e5 = k.layers.UpSampling2D(size=(16, 16), interpolation='bilinear')(e5)  # --> 320*320*OUTPUT_CHANNELS
    if OUTPUT_CHANNELS == 1:
        d1 = k.activations.sigmoid(d1)
        d2 = k.activations.sigmoid(d2)
        d3 = k.activations.sigmoid(d3)
        d4 = k.activations.sigmoid(d4)
        e5 = k.activations.sigmoid(e5)
    else:
        d1 = k.activations.softmax(d1)
        d2 = k.activations.softmax(d2)
        d3 = k.activations.softmax(d3)
        d4 = k.activations.softmax(d4)
        e5 = k.activations.softmax(e5)
    model = tf.keras.Model(inputs=input_layer, outputs=[d1, d2, d3, d4, e5], name='UNet_3Plus_DeepSup')
    if(pretrained_weights):
        model.load_weights(pretrained_weights)
    return model
""" UNet_3Plus with Deep Supervison and Classification Guided Module""" | |
def UNet_3Plus_DeepSup_CGM(INPUT_SHAPE, OUTPUT_CHANNELS, pretrained_weights = None):
    """Build UNet 3+ with deep supervision and a classification-guided module.

    The encoder, full-scale decoder and five deep-supervision heads match
    ``UNet_3Plus_DeepSup``. In addition, a small classifier on the
    bottleneck ``e5`` yields one 0/1 flag per sample. ``dotProduct``
    multiplies every head by that flag before the final sigmoid or softmax.

    Args:
        INPUT_SHAPE: Per-sample input shape given to ``k.layers.Input``. The
            shape comments below assume ``(320, 320, 3)``. Height and width
            must be divisible by 16 so the pooled and up-sampled maps line
            up when concatenated.
        OUTPUT_CHANNELS: Channels of every output map. ``1`` selects
            sigmoid outputs; any other value selects softmax over the last
            axis (the Keras default).
        pretrained_weights: Optional path. When truthy, it is passed to
            ``model.load_weights`` after the model is built.

    Returns:
        A ``tf.keras.Model`` named ``'UNet_3Plus_DeepSup_CGM'`` with five
        outputs ``[d1, d2, d3, d4, e5]``. Each has the input's height and
        width and ``OUTPUT_CHANNELS`` channels. The classification flag
        itself is not exposed as an output.
    """
    # NOTE(review): keep the statement order. Keras auto-names layers in
    # creation order, and order-based weight loading depends on the layer
    # sequence. Reordering may break loading of existing pretrained_weights.
    filters = [64, 128, 256, 512, 1024]
    input_layer = k.layers.Input(shape=INPUT_SHAPE, name="input_layer")  # 320*320*3
    """ Encoder"""
    # block 1
    e1 = conv_block(input_layer, filters[0])  # 320*320*64
    # block 2
    e2 = k.layers.MaxPool2D(pool_size=(2, 2))(e1)  # 160*160*64
    e2 = conv_block(e2, filters[1])  # 160*160*128
    # block 3
    e3 = k.layers.MaxPool2D(pool_size=(2, 2))(e2)  # 80*80*128
    e3 = conv_block(e3, filters[2])  # 80*80*256
    # block 4
    e4 = k.layers.MaxPool2D(pool_size=(2, 2))(e3)  # 40*40*256
    e4 = conv_block(e4, filters[3])  # 40*40*512
    # block 5, bottleneck layer
    e5 = k.layers.MaxPool2D(pool_size=(2, 2))(e4)  # 20*20*512
    e5 = conv_block(e5, filters[4])  # 20*20*1024
    """ Classification Guided Module. Part 1"""
    # Per-sample binary classifier on the bottleneck features:
    # 20*20*1024 -> dropout -> 1x1 conv 20*20*2 -> global max pool (batch, 2)
    # -> sigmoid -> argmax (batch,) -> (batch, 1) float32 holding 0.0 or 1.0.
    # The sigmoid is monotonic, so it does not change the argmax.
    # NOTE(review): tf.argmax has no gradient, and cls is not a model output.
    # As written, the Conv2D below therefore gets no training signal through
    # this model and keeps its initial weights unless trained elsewhere.
    # Confirm this is intended.
    cls = k.layers.Dropout(rate=0.5)(e5)
    cls = k.layers.Conv2D(2, kernel_size=(1, 1), padding="same", strides=(1, 1))(cls)
    cls = k.layers.GlobalMaxPooling2D()(cls)
    cls = k.activations.sigmoid(cls)
    cls = tf.argmax(cls, axis=-1)
    cls = cls[..., tf.newaxis]
    cls = tf.cast(cls, dtype=tf.float32, )
    """ Decoder """
    # Every branch feeding a decoder stage is reduced to cat_channels (64)
    # channels. Each fused stage therefore has 5 * 64 = 320 channels.
    cat_channels = filters[0]
    cat_blocks = len(filters)
    upsample_channels = cat_blocks * cat_channels
    """ d4 """
    e1_d4 = k.layers.MaxPool2D(pool_size=(8, 8))(e1)  # 320*320*64 --> 40*40*64
    e1_d4 = conv_block(e1_d4, cat_channels, n=1)  # 40*40*64 --> 40*40*64
    e2_d4 = k.layers.MaxPool2D(pool_size=(4, 4))(e2)  # 160*160*128 --> 40*40*128
    e2_d4 = conv_block(e2_d4, cat_channels, n=1)  # 40*40*128 --> 40*40*64
    e3_d4 = k.layers.MaxPool2D(pool_size=(2, 2))(e3)  # 80*80*256 --> 40*40*256
    e3_d4 = conv_block(e3_d4, cat_channels, n=1)  # 40*40*256 --> 40*40*64
    e4_d4 = conv_block(e4, cat_channels, n=1)  # 40*40*512 --> 40*40*64
    e5_d4 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(e5)  # 20*20*1024 --> 40*40*1024
    e5_d4 = conv_block(e5_d4, cat_channels, n=1)  # 40*40*1024 --> 40*40*64
    d4 = k.layers.concatenate([e1_d4, e2_d4, e3_d4, e4_d4, e5_d4])  # 40*40*320
    d4 = conv_block(d4, upsample_channels, n=1)  # 40*40*320 --> 40*40*320
    """ d3 """
    e1_d3 = k.layers.MaxPool2D(pool_size=(4, 4))(e1)  # 320*320*64 --> 80*80*64
    e1_d3 = conv_block(e1_d3, cat_channels, n=1)  # 80*80*64 --> 80*80*64
    e2_d3 = k.layers.MaxPool2D(pool_size=(2, 2))(e2)  # 160*160*128 --> 80*80*128
    e2_d3 = conv_block(e2_d3, cat_channels, n=1)  # 80*80*128 --> 80*80*64
    e3_d3 = conv_block(e3, cat_channels, n=1)  # 80*80*256 --> 80*80*64
    # Despite its name, this branch up-samples the decoder output d4, not e4.
    e4_d3 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d4)  # 40*40*320 --> 80*80*320
    e4_d3 = conv_block(e4_d3, cat_channels, n=1)  # 80*80*320 --> 80*80*64
    e5_d3 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(e5)  # 20*20*1024 --> 80*80*1024
    e5_d3 = conv_block(e5_d3, cat_channels, n=1)  # 80*80*1024 --> 80*80*64
    d3 = k.layers.concatenate([e1_d3, e2_d3, e3_d3, e4_d3, e5_d3])  # 80*80*320
    d3 = conv_block(d3, upsample_channels, n=1)  # 80*80*320 --> 80*80*320
    """ d2 """
    e1_d2 = k.layers.MaxPool2D(pool_size=(2, 2))(e1)  # 320*320*64 --> 160*160*64
    e1_d2 = conv_block(e1_d2, cat_channels, n=1)  # 160*160*64 --> 160*160*64
    e2_d2 = conv_block(e2, cat_channels, n=1)  # 160*160*128 --> 160*160*64
    d3_d2 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d3)  # 80*80*320 --> 160*160*320
    d3_d2 = conv_block(d3_d2, cat_channels, n=1)  # 160*160*320 --> 160*160*64
    d4_d2 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d4)  # 40*40*320 --> 160*160*320
    d4_d2 = conv_block(d4_d2, cat_channels, n=1)  # 160*160*320 --> 160*160*64
    e5_d2 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(e5)  # 20*20*1024 --> 160*160*1024
    e5_d2 = conv_block(e5_d2, cat_channels, n=1)  # 160*160*1024 --> 160*160*64
    d2 = k.layers.concatenate([e1_d2, e2_d2, d3_d2, d4_d2, e5_d2])  # 160*160*320
    d2 = conv_block(d2, upsample_channels, n=1)  # 160*160*320 --> 160*160*320
    """ d1 """
    e1_d1 = conv_block(e1, cat_channels, n=1)  # 320*320*64 --> 320*320*64
    d2_d1 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d2)  # 160*160*320 --> 320*320*320
    d2_d1 = conv_block(d2_d1, cat_channels, n=1)  # 320*320*320 --> 320*320*64
    d3_d1 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d3)  # 80*80*320 --> 320*320*320
    d3_d1 = conv_block(d3_d1, cat_channels, n=1)  # 320*320*320 --> 320*320*64
    d4_d1 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(d4)  # 40*40*320 --> 320*320*320
    d4_d1 = conv_block(d4_d1, cat_channels, n=1)  # 320*320*320 --> 320*320*64
    e5_d1 = k.layers.UpSampling2D(size=(16, 16), interpolation='bilinear')(e5)  # 20*20*1024 --> 320*320*1024
    e5_d1 = conv_block(e5_d1, cat_channels, n=1)  # 320*320*1024 --> 320*320*64
    d1 = k.layers.concatenate([e1_d1, d2_d1, d3_d1, d4_d1, e5_d1, ])  # 320*320*320
    d1 = conv_block(d1, upsample_channels, n=1)  # 320*320*320 --> 320*320*320
    """ Deep Supervision Part"""
    # One head per scale: a plain 3x3 conv to OUTPUT_CHANNELS (pre-activation
    # values), then bilinear up-sampling back to the input resolution.
    # last layer does not have batchnorm and relu
    d1 = conv_block(d1, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)  # 320*320*OUTPUT_CHANNELS
    d2 = conv_block(d2, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)  # 160*160*OUTPUT_CHANNELS
    d3 = conv_block(d3, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)  # 80*80*OUTPUT_CHANNELS
    d4 = conv_block(d4, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)  # 40*40*OUTPUT_CHANNELS
    e5 = conv_block(e5, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)  # 20*20*OUTPUT_CHANNELS
    # d1 = no need for upsampling
    d2 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d2)  # --> 320*320*OUTPUT_CHANNELS
    d3 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d3)  # --> 320*320*OUTPUT_CHANNELS
    d4 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(d4)  # --> 320*320*OUTPUT_CHANNELS
    e5 = k.layers.UpSampling2D(size=(16, 16), interpolation='bilinear')(e5)  # --> 320*320*OUTPUT_CHANNELS
    """ Classification Guided Module. Part 2"""
    # Multiply every head by its sample's class flag. cls has shape
    # (batch, 1), so for OUTPUT_CHANNELS > 1 dotProduct must broadcast that
    # single factor across all channels. Verify dotProduct supports it.
    # NOTE(review): the product is taken on pre-activation values. A sample
    # with cls == 0 thus becomes sigmoid(0) = 0.5 everywhere (or a uniform
    # softmax), not an empty mask. If the goal is to suppress predictions
    # for negative samples, multiply after the activation instead.
    d1 = dotProduct(d1, cls)
    d2 = dotProduct(d2, cls)
    d3 = dotProduct(d3, cls)
    d4 = dotProduct(d4, cls)
    e5 = dotProduct(e5, cls)
    if OUTPUT_CHANNELS == 1:
        d1 = k.activations.sigmoid(d1)
        d2 = k.activations.sigmoid(d2)
        d3 = k.activations.sigmoid(d3)
        d4 = k.activations.sigmoid(d4)
        e5 = k.activations.sigmoid(e5)
    else:
        d1 = k.activations.softmax(d1)
        d2 = k.activations.softmax(d2)
        d3 = k.activations.softmax(d3)
        d4 = k.activations.softmax(d4)
        e5 = k.activations.softmax(e5)
    model = tf.keras.Model(inputs=input_layer, outputs=[d1, d2, d3, d4, e5], name='UNet_3Plus_DeepSup_CGM')
    if(pretrained_weights):
        model.load_weights(pretrained_weights)
    return model