bumble-bee committed
Commit 6bf4d42 · Parent: 80a7fd8

added unet files
Files changed (5)
  1. .gitignore +2 -0
  2. app.py +0 -8
  3. predict_unet.py +3 -10
  4. unet/unet.py +60 -0
  5. unet/unet_3plus.py +440 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__
+ unet/__pycache__
app.py CHANGED
@@ -8,14 +8,6 @@ from PIL import Image
  from predict_unet import predict_model


- # device = torch.device(
- #     "cuda"
- #     if torch.cuda.is_available()
- #     else "mps"
- #     if torch.backends.mps.is_available()
- #     else "cpu"
- # )
-
  title = "<center><strong><font size='8'> Medical Image Segmentation with UNet </font></strong></center>"

  examples = [["examples/50494616.jpg"], ["examples/50494676.jpg"], ["examples/56399783.jpg"],
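Note: app.py keeps the predict_model import plus the Gradio-style title/examples strings. A minimal sketch of how these pieces are typically wired together (the gr.Interface call, component types, and the fixed "unet" choice are assumptions, not part of this diff):

# Hypothetical sketch of how app.py may wire predict_model into a Gradio demo;
# the component types and the fixed "unet" choice below are assumptions.
import gradio as gr

from predict_unet import predict_model

title = "<center><strong><font size='8'> Medical Image Segmentation with UNet </font></strong></center>"
examples = [["examples/50494616.jpg"], ["examples/50494676.jpg"], ["examples/56399783.jpg"]]

demo = gr.Interface(
    fn=lambda img: predict_model(img, "unet"),  # unet_type fixed here for illustration
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(label="Predicted mask"),
    title=title,
    examples=examples,
)

if __name__ == "__main__":
    demo.launch()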
predict_unet.py CHANGED
@@ -1,21 +1,14 @@

  import os
  import numpy as np
- import skimage.io as skio
  import skimage.transform as trans
  from skimage.color import rgb2gray
- from matplotlib import pyplot as plt
- import sys
-
- sys.path.append("/panfs/jay/groups/29/umii/mo000007/zooniverse/UNet")
-
- from utils import *
- from unet import unet
- from unet_3plus import UNet_3Plus, UNet_3Plus_DeepSup, UNet_3Plus_DeepSup_CGM
+ from unet.unet import unet
+ from unet.unet_3plus import UNet_3Plus, UNet_3Plus_DeepSup, UNet_3Plus_DeepSup_CGM


  def predict_model(input, unet_type):
-     model_path = "/home/umii/mo000007/zooniverse/UNet/trained_models"
+     model_path = "unet/trained_models"
      h, w = 256, 256
      input_shape = [h, w, 1]
      output_channels = 1
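Note: the remaining imports (skimage.transform, rgb2gray) and the 256x256 input shape imply the usual preprocessing before inference. A minimal sketch under those assumptions (the helper name prepare_input is hypothetical; the body of predict_model is not shown in this diff):

import numpy as np
import skimage.transform as trans
from skimage.color import rgb2gray

def prepare_input(img, h=256, w=256):
    # Hypothetical helper: convert to grayscale, resize to the model's
    # 256x256 input, scale to [0, 1], and add batch/channel axes.
    img = np.asarray(img)
    if img.ndim == 3:
        img = rgb2gray(img)                  # H*W*3 -> H*W, values in [0, 1]
    img = trans.resize(img, (h, w))          # resize to the network input size
    img = img.astype(np.float32)
    if img.max() > 1.0:
        img = img / 255.0
    return img[np.newaxis, ..., np.newaxis]  # shape (1, h, w, 1)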
unet/unet.py ADDED
@@ -0,0 +1,60 @@
+ # Build U-Net model
+ import tensorflow as tf
+ import tensorflow.keras.layers as layers
+ import tensorflow.keras.models as models
+ import tensorflow.keras.metrics as metrics
+ #from keras import backend as keras
+
+
+ def unet(pretrained_weights = None, input_size = (256,256,1)):
+     inputs = layers.Input(input_size)
+     conv1 = layers.Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)
+     conv1 = layers.Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)
+     pool1 = layers.MaxPooling2D(pool_size=(2, 2))(conv1)
+
+     conv2 = layers.Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
+     conv2 = layers.Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)
+     pool2 = layers.MaxPooling2D(pool_size=(2, 2))(conv2)
+
+     conv3 = layers.Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
+     conv3 = layers.Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)
+     pool3 = layers.MaxPooling2D(pool_size=(2, 2))(conv3)
+
+     conv4 = layers.Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)
+     conv4 = layers.Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
+     drop4 = layers.Dropout(0.5)(conv4)
+     pool4 = layers.MaxPooling2D(pool_size=(2, 2))(drop4)
+
+     conv5 = layers.Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)
+     conv5 = layers.Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)
+     drop5 = layers.Dropout(0.5)(conv5)
+
+     up6 = layers.Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(layers.UpSampling2D(size = (2,2))(drop5))
+     merge6 = layers.concatenate([drop4,up6], axis = 3)
+     conv6 = layers.Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
+     conv6 = layers.Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)
+
+     up7 = layers.Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(layers.UpSampling2D(size = (2,2))(conv6))
+     merge7 = layers.concatenate([conv3,up7], axis = 3)
+     conv7 = layers.Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
+     conv7 = layers.Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)
+
+     up8 = layers.Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(layers.UpSampling2D(size = (2,2))(conv7))
+     merge8 = layers.concatenate([conv2,up8], axis = 3)
+     conv8 = layers.Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
+     conv8 = layers.Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)
+
+     up9 = layers.Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(layers.UpSampling2D(size = (2,2))(conv8))
+     merge9 = layers.concatenate([conv1,up9], axis = 3)
+     conv9 = layers.Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9)
+     conv9 = layers.Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
+     conv9 = layers.Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
+
+     conv10 = layers.Conv2D(1, 1, activation = 'sigmoid')(conv9)
+
+     model = models.Model(inputs=inputs, outputs=conv10)
+
+     if(pretrained_weights):
+         model.load_weights(pretrained_weights)
+
+     return model
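Note: a minimal usage sketch for the file above, using the same 256x256x1 input size that predict_unet.py passes in (the dummy input and summary call are illustrative only):

import numpy as np
from unet.unet import unet

# Build the plain U-Net with the input size used by predict_unet.py.
model = unet(input_size=(256, 256, 1))
model.summary()

# Forward pass on a dummy grayscale batch; the output is a (1, 256, 256, 1)
# sigmoid mask with values in [0, 1].
dummy = np.random.rand(1, 256, 256, 1).astype("float32")
mask = model.predict(dummy)
print(mask.shape)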
unet/unet_3plus.py ADDED
@@ -0,0 +1,440 @@
+ import tensorflow as tf
+ import tensorflow.keras as k
+
+
+ # Model Architecture
+ def conv_block(x, kernels, kernel_size=(3, 3), strides=(1, 1), padding='same',
+                is_bn=True, is_relu=True, n=2):
+     """ Custom function for conv2d:
+     Apply 3*3 convolutions with BN and relu.
+     """
+     for i in range(1, n + 1):
+         x = k.layers.Conv2D(filters=kernels, kernel_size=kernel_size,
+                             padding=padding, strides=strides,
+                             kernel_regularizer=tf.keras.regularizers.l2(1e-4),
+                             kernel_initializer=k.initializers.he_normal(seed=5))(x)
+         if is_bn:
+             x = k.layers.BatchNormalization()(x)
+         if is_relu:
+             x = k.activations.relu(x)
+
+     return x
+
+
+ def dotProduct(seg, cls):
+     B, H, W, N = k.backend.int_shape(seg)
+     seg = tf.reshape(seg, [-1, H * W, N])
+     final = tf.einsum("ijk,ik->ijk", seg, cls)
+     final = tf.reshape(final, [-1, H, W, N])
+     return final
+
+
+ """ UNet_3Plus """
+ def UNet_3Plus(INPUT_SHAPE, OUTPUT_CHANNELS, pretrained_weights = None):
+     filters = [64, 128, 256, 512, 1024]
+
+     input_layer = k.layers.Input(shape=INPUT_SHAPE, name="input_layer")  # 320*320*3
+
+     """ Encoder"""
+     # block 1
+     e1 = conv_block(input_layer, filters[0])  # 320*320*64
+
+     # block 2
+     e2 = k.layers.MaxPool2D(pool_size=(2, 2))(e1)  # 160*160*64
+     e2 = conv_block(e2, filters[1])  # 160*160*128
+
+     # block 3
+     e3 = k.layers.MaxPool2D(pool_size=(2, 2))(e2)  # 80*80*128
+     e3 = conv_block(e3, filters[2])  # 80*80*256
+
+     # block 4
+     e4 = k.layers.MaxPool2D(pool_size=(2, 2))(e3)  # 40*40*256
+     e4 = conv_block(e4, filters[3])  # 40*40*512
+
+     # block 5
+     # bottleneck layer
+     e5 = k.layers.MaxPool2D(pool_size=(2, 2))(e4)  # 20*20*512
+     e5 = conv_block(e5, filters[4])  # 20*20*1024
+
+     """ Decoder """
+     cat_channels = filters[0]
+     cat_blocks = len(filters)
+     upsample_channels = cat_blocks * cat_channels
+
+     """ d4 """
+     e1_d4 = k.layers.MaxPool2D(pool_size=(8, 8))(e1)  # 320*320*64 --> 40*40*64
+     e1_d4 = conv_block(e1_d4, cat_channels, n=1)  # 320*320*64 --> 40*40*64
+
+     e2_d4 = k.layers.MaxPool2D(pool_size=(4, 4))(e2)  # 160*160*128 --> 40*40*128
+     e2_d4 = conv_block(e2_d4, cat_channels, n=1)  # 160*160*128 --> 40*40*64
+
+     e3_d4 = k.layers.MaxPool2D(pool_size=(2, 2))(e3)  # 80*80*256 --> 40*40*256
+     e3_d4 = conv_block(e3_d4, cat_channels, n=1)  # 80*80*256 --> 40*40*64
+
+     e4_d4 = conv_block(e4, cat_channels, n=1)  # 40*40*512 --> 40*40*64
+
+     e5_d4 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(e5)  # 80*80*256 --> 40*40*256
+     e5_d4 = conv_block(e5_d4, cat_channels, n=1)  # 20*20*1024 --> 20*20*64
+
+     d4 = k.layers.concatenate([e1_d4, e2_d4, e3_d4, e4_d4, e5_d4])
+     d4 = conv_block(d4, upsample_channels, n=1)  # 40*40*320 --> 40*40*320
+
+     """ d3 """
+     e1_d3 = k.layers.MaxPool2D(pool_size=(4, 4))(e1)  # 320*320*64 --> 80*80*64
+     e1_d3 = conv_block(e1_d3, cat_channels, n=1)  # 80*80*64 --> 80*80*64
+
+     e2_d3 = k.layers.MaxPool2D(pool_size=(2, 2))(e2)  # 160*160*256 --> 80*80*256
+     e2_d3 = conv_block(e2_d3, cat_channels, n=1)  # 80*80*256 --> 80*80*64
+
+     e3_d3 = conv_block(e3, cat_channels, n=1)  # 80*80*512 --> 80*80*64
+
+     e4_d3 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d4)  # 40*40*320 --> 80*80*320
+     e4_d3 = conv_block(e4_d3, cat_channels, n=1)  # 80*80*320 --> 80*80*64
+
+     e5_d3 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(e5)  # 20*20*320 --> 80*80*320
+     e5_d3 = conv_block(e5_d3, cat_channels, n=1)  # 80*80*320 --> 80*80*64
+
+     d3 = k.layers.concatenate([e1_d3, e2_d3, e3_d3, e4_d3, e5_d3])
+     d3 = conv_block(d3, upsample_channels, n=1)  # 80*80*320 --> 80*80*320
+
+     """ d2 """
+     e1_d2 = k.layers.MaxPool2D(pool_size=(2, 2))(e1)  # 320*320*64 --> 160*160*64
+     e1_d2 = conv_block(e1_d2, cat_channels, n=1)  # 160*160*64 --> 160*160*64
+
+     e2_d2 = conv_block(e2, cat_channels, n=1)  # 160*160*256 --> 160*160*64
+
+     d3_d2 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d3)  # 80*80*320 --> 160*160*320
+     d3_d2 = conv_block(d3_d2, cat_channels, n=1)  # 160*160*320 --> 160*160*64
+
+     d4_d2 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d4)  # 40*40*320 --> 160*160*320
+     d4_d2 = conv_block(d4_d2, cat_channels, n=1)  # 160*160*320 --> 160*160*64
+
+     e5_d2 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(e5)  # 20*20*320 --> 160*160*320
+     e5_d2 = conv_block(e5_d2, cat_channels, n=1)  # 160*160*320 --> 160*160*64
+
+     d2 = k.layers.concatenate([e1_d2, e2_d2, d3_d2, d4_d2, e5_d2])
+     d2 = conv_block(d2, upsample_channels, n=1)  # 160*160*320 --> 160*160*320
+
+     """ d1 """
+     e1_d1 = conv_block(e1, cat_channels, n=1)  # 320*320*64 --> 320*320*64
+
+     d2_d1 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d2)  # 160*160*320 --> 320*320*320
+     d2_d1 = conv_block(d2_d1, cat_channels, n=1)  # 160*160*320 --> 160*160*64
+
+     d3_d1 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d3)  # 80*80*320 --> 320*320*320
+     d3_d1 = conv_block(d3_d1, cat_channels, n=1)  # 320*320*320 --> 320*320*64
+
+     d4_d1 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(d4)  # 40*40*320 --> 320*320*320
+     d4_d1 = conv_block(d4_d1, cat_channels, n=1)  # 320*320*320 --> 320*320*64
+
+     e5_d1 = k.layers.UpSampling2D(size=(16, 16), interpolation='bilinear')(e5)  # 20*20*320 --> 320*320*320
+     e5_d1 = conv_block(e5_d1, cat_channels, n=1)  # 320*320*320 --> 320*320*64
+
+     d1 = k.layers.concatenate([e1_d1, d2_d1, d3_d1, d4_d1, e5_d1, ])
+     d1 = conv_block(d1, upsample_channels, n=1)  # 320*320*320 --> 320*320*320
+
+     # last layer does not have batchnorm and relu
+     d = conv_block(d1, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
+
+     if OUTPUT_CHANNELS == 1:
+         output = k.activations.sigmoid(d)
+     else:
+         output = k.activations.softmax(d)
+
+     model = tf.keras.Model(inputs=input_layer, outputs=output, name='UNet_3Plus')
+     if(pretrained_weights):
+         model.load_weights(pretrained_weights)
+
+     return model
+
+
+ """ UNet_3Plus with Deep Supervision"""
+ def UNet_3Plus_DeepSup(INPUT_SHAPE, OUTPUT_CHANNELS, pretrained_weights = None):
+     filters = [64, 128, 256, 512, 1024]
+
+     input_layer = k.layers.Input(shape=INPUT_SHAPE, name="input_layer")  # 320*320*3
+
+     """ Encoder"""
+     # block 1
+     e1 = conv_block(input_layer, filters[0])  # 320*320*64
+
+     # block 2
+     e2 = k.layers.MaxPool2D(pool_size=(2, 2))(e1)  # 160*160*64
+     e2 = conv_block(e2, filters[1])  # 160*160*128
+
+     # block 3
+     e3 = k.layers.MaxPool2D(pool_size=(2, 2))(e2)  # 80*80*128
+     e3 = conv_block(e3, filters[2])  # 80*80*256
+
+     # block 4
+     e4 = k.layers.MaxPool2D(pool_size=(2, 2))(e3)  # 40*40*256
+     e4 = conv_block(e4, filters[3])  # 40*40*512
+
+     # block 5
+     # bottleneck layer
+     e5 = k.layers.MaxPool2D(pool_size=(2, 2))(e4)  # 20*20*512
+     e5 = conv_block(e5, filters[4])  # 20*20*1024
+
+     """ Decoder """
+     cat_channels = filters[0]
+     cat_blocks = len(filters)
+     upsample_channels = cat_blocks * cat_channels
+
+     """ d4 """
+     e1_d4 = k.layers.MaxPool2D(pool_size=(8, 8))(e1)  # 320*320*64 --> 40*40*64
+     e1_d4 = conv_block(e1_d4, cat_channels, n=1)  # 320*320*64 --> 40*40*64
+
+     e2_d4 = k.layers.MaxPool2D(pool_size=(4, 4))(e2)  # 160*160*128 --> 40*40*128
+     e2_d4 = conv_block(e2_d4, cat_channels, n=1)  # 160*160*128 --> 40*40*64
+
+     e3_d4 = k.layers.MaxPool2D(pool_size=(2, 2))(e3)  # 80*80*256 --> 40*40*256
+     e3_d4 = conv_block(e3_d4, cat_channels, n=1)  # 80*80*256 --> 40*40*64
+
+     e4_d4 = conv_block(e4, cat_channels, n=1)  # 40*40*512 --> 40*40*64
+
+     e5_d4 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(e5)  # 80*80*256 --> 40*40*256
+     e5_d4 = conv_block(e5_d4, cat_channels, n=1)  # 20*20*1024 --> 20*20*64
+
+     d4 = k.layers.concatenate([e1_d4, e2_d4, e3_d4, e4_d4, e5_d4])
+     d4 = conv_block(d4, upsample_channels, n=1)  # 40*40*320 --> 40*40*320
+
+     """ d3 """
+     e1_d3 = k.layers.MaxPool2D(pool_size=(4, 4))(e1)  # 320*320*64 --> 80*80*64
+     e1_d3 = conv_block(e1_d3, cat_channels, n=1)  # 80*80*64 --> 80*80*64
+
+     e2_d3 = k.layers.MaxPool2D(pool_size=(2, 2))(e2)  # 160*160*256 --> 80*80*256
+     e2_d3 = conv_block(e2_d3, cat_channels, n=1)  # 80*80*256 --> 80*80*64
+
+     e3_d3 = conv_block(e3, cat_channels, n=1)  # 80*80*512 --> 80*80*64
+
+     e4_d3 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d4)  # 40*40*320 --> 80*80*320
+     e4_d3 = conv_block(e4_d3, cat_channels, n=1)  # 80*80*320 --> 80*80*64
+
+     e5_d3 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(e5)  # 20*20*320 --> 80*80*320
+     e5_d3 = conv_block(e5_d3, cat_channels, n=1)  # 80*80*320 --> 80*80*64
+
+     d3 = k.layers.concatenate([e1_d3, e2_d3, e3_d3, e4_d3, e5_d3])
+     d3 = conv_block(d3, upsample_channels, n=1)  # 80*80*320 --> 80*80*320
+
+     """ d2 """
+     e1_d2 = k.layers.MaxPool2D(pool_size=(2, 2))(e1)  # 320*320*64 --> 160*160*64
+     e1_d2 = conv_block(e1_d2, cat_channels, n=1)  # 160*160*64 --> 160*160*64
+
+     e2_d2 = conv_block(e2, cat_channels, n=1)  # 160*160*256 --> 160*160*64
+
+     d3_d2 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d3)  # 80*80*320 --> 160*160*320
+     d3_d2 = conv_block(d3_d2, cat_channels, n=1)  # 160*160*320 --> 160*160*64
+
+     d4_d2 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d4)  # 40*40*320 --> 160*160*320
+     d4_d2 = conv_block(d4_d2, cat_channels, n=1)  # 160*160*320 --> 160*160*64
+
+     e5_d2 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(e5)  # 20*20*320 --> 160*160*320
+     e5_d2 = conv_block(e5_d2, cat_channels, n=1)  # 160*160*320 --> 160*160*64
+
+     d2 = k.layers.concatenate([e1_d2, e2_d2, d3_d2, d4_d2, e5_d2])
+     d2 = conv_block(d2, upsample_channels, n=1)  # 160*160*320 --> 160*160*320
+
+     """ d1 """
+     e1_d1 = conv_block(e1, cat_channels, n=1)  # 320*320*64 --> 320*320*64
+
+     d2_d1 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d2)  # 160*160*320 --> 320*320*320
+     d2_d1 = conv_block(d2_d1, cat_channels, n=1)  # 160*160*320 --> 160*160*64
+
+     d3_d1 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d3)  # 80*80*320 --> 320*320*320
+     d3_d1 = conv_block(d3_d1, cat_channels, n=1)  # 320*320*320 --> 320*320*64
+
+     d4_d1 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(d4)  # 40*40*320 --> 320*320*320
+     d4_d1 = conv_block(d4_d1, cat_channels, n=1)  # 320*320*320 --> 320*320*64
+
+     e5_d1 = k.layers.UpSampling2D(size=(16, 16), interpolation='bilinear')(e5)  # 20*20*320 --> 320*320*320
+     e5_d1 = conv_block(e5_d1, cat_channels, n=1)  # 320*320*320 --> 320*320*64
+
+     d1 = k.layers.concatenate([e1_d1, d2_d1, d3_d1, d4_d1, e5_d1, ])
+     d1 = conv_block(d1, upsample_channels, n=1)  # 320*320*320 --> 320*320*320
+
+     """ Deep Supervision Part"""
+     # last layer does not have batchnorm and relu
+     d1 = conv_block(d1, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
+     d2 = conv_block(d2, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
+     d3 = conv_block(d3, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
+     d4 = conv_block(d4, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
+     e5 = conv_block(e5, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
+
+     # d1 = no need for upsampling
+     d2 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d2)
+     d3 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d3)
+     d4 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(d4)
+     e5 = k.layers.UpSampling2D(size=(16, 16), interpolation='bilinear')(e5)
+
+     if OUTPUT_CHANNELS == 1:
+         d1 = k.activations.sigmoid(d1)
+         d2 = k.activations.sigmoid(d2)
+         d3 = k.activations.sigmoid(d3)
+         d4 = k.activations.sigmoid(d4)
+         e5 = k.activations.sigmoid(e5)
+     else:
+         d1 = k.activations.softmax(d1)
+         d2 = k.activations.softmax(d2)
+         d3 = k.activations.softmax(d3)
+         d4 = k.activations.softmax(d4)
+         e5 = k.activations.softmax(e5)
+
+     model = tf.keras.Model(inputs=input_layer, outputs=[d1, d2, d3, d4, e5], name='UNet_3Plus_DeepSup')
+
+     if(pretrained_weights):
+         model.load_weights(pretrained_weights)
+
+     return model
+
+
+ """ UNet_3Plus with Deep Supervision and Classification Guided Module"""
+ def UNet_3Plus_DeepSup_CGM(INPUT_SHAPE, OUTPUT_CHANNELS, pretrained_weights = None):
+     filters = [64, 128, 256, 512, 1024]
+
+     input_layer = k.layers.Input(shape=INPUT_SHAPE, name="input_layer")  # 320*320*3
+
+     """ Encoder"""
+     # block 1
+     e1 = conv_block(input_layer, filters[0])  # 320*320*64
+
+     # block 2
+     e2 = k.layers.MaxPool2D(pool_size=(2, 2))(e1)  # 160*160*64
+     e2 = conv_block(e2, filters[1])  # 160*160*128
+
+     # block 3
+     e3 = k.layers.MaxPool2D(pool_size=(2, 2))(e2)  # 80*80*128
+     e3 = conv_block(e3, filters[2])  # 80*80*256
+
+     # block 4
+     e4 = k.layers.MaxPool2D(pool_size=(2, 2))(e3)  # 40*40*256
+     e4 = conv_block(e4, filters[3])  # 40*40*512
+
+     # block 5, bottleneck layer
+     e5 = k.layers.MaxPool2D(pool_size=(2, 2))(e4)  # 20*20*512
+     e5 = conv_block(e5, filters[4])  # 20*20*1024
+
+     """ Classification Guided Module. Part 1"""
+     cls = k.layers.Dropout(rate=0.5)(e5)
+     cls = k.layers.Conv2D(2, kernel_size=(1, 1), padding="same", strides=(1, 1))(cls)
+     cls = k.layers.GlobalMaxPooling2D()(cls)
+     cls = k.activations.sigmoid(cls)
+     cls = tf.argmax(cls, axis=-1)
+     cls = cls[..., tf.newaxis]
+     cls = tf.cast(cls, dtype=tf.float32, )
+
+     """ Decoder """
+     cat_channels = filters[0]
+     cat_blocks = len(filters)
+     upsample_channels = cat_blocks * cat_channels
+
+     """ d4 """
+     e1_d4 = k.layers.MaxPool2D(pool_size=(8, 8))(e1)  # 320*320*64 --> 40*40*64
+     e1_d4 = conv_block(e1_d4, cat_channels, n=1)  # 320*320*64 --> 40*40*64
+
+     e2_d4 = k.layers.MaxPool2D(pool_size=(4, 4))(e2)  # 160*160*128 --> 40*40*128
+     e2_d4 = conv_block(e2_d4, cat_channels, n=1)  # 160*160*128 --> 40*40*64
+
+     e3_d4 = k.layers.MaxPool2D(pool_size=(2, 2))(e3)  # 80*80*256 --> 40*40*256
+     e3_d4 = conv_block(e3_d4, cat_channels, n=1)  # 80*80*256 --> 40*40*64
+
+     e4_d4 = conv_block(e4, cat_channels, n=1)  # 40*40*512 --> 40*40*64
+
+     e5_d4 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(e5)  # 80*80*256 --> 40*40*256
+     e5_d4 = conv_block(e5_d4, cat_channels, n=1)  # 20*20*1024 --> 20*20*64
+
+     d4 = k.layers.concatenate([e1_d4, e2_d4, e3_d4, e4_d4, e5_d4])
+     d4 = conv_block(d4, upsample_channels, n=1)  # 40*40*320 --> 40*40*320
+
+     """ d3 """
+     e1_d3 = k.layers.MaxPool2D(pool_size=(4, 4))(e1)  # 320*320*64 --> 80*80*64
+     e1_d3 = conv_block(e1_d3, cat_channels, n=1)  # 80*80*64 --> 80*80*64
+
+     e2_d3 = k.layers.MaxPool2D(pool_size=(2, 2))(e2)  # 160*160*256 --> 80*80*256
+     e2_d3 = conv_block(e2_d3, cat_channels, n=1)  # 80*80*256 --> 80*80*64
+
+     e3_d3 = conv_block(e3, cat_channels, n=1)  # 80*80*512 --> 80*80*64
+
+     e4_d3 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d4)  # 40*40*320 --> 80*80*320
+     e4_d3 = conv_block(e4_d3, cat_channels, n=1)  # 80*80*320 --> 80*80*64
+
+     e5_d3 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(e5)  # 20*20*320 --> 80*80*320
+     e5_d3 = conv_block(e5_d3, cat_channels, n=1)  # 80*80*320 --> 80*80*64
+
+     d3 = k.layers.concatenate([e1_d3, e2_d3, e3_d3, e4_d3, e5_d3])
+     d3 = conv_block(d3, upsample_channels, n=1)  # 80*80*320 --> 80*80*320
+
+     """ d2 """
+     e1_d2 = k.layers.MaxPool2D(pool_size=(2, 2))(e1)  # 320*320*64 --> 160*160*64
+     e1_d2 = conv_block(e1_d2, cat_channels, n=1)  # 160*160*64 --> 160*160*64
+
+     e2_d2 = conv_block(e2, cat_channels, n=1)  # 160*160*256 --> 160*160*64
+
+     d3_d2 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d3)  # 80*80*320 --> 160*160*320
+     d3_d2 = conv_block(d3_d2, cat_channels, n=1)  # 160*160*320 --> 160*160*64
+
+     d4_d2 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d4)  # 40*40*320 --> 160*160*320
+     d4_d2 = conv_block(d4_d2, cat_channels, n=1)  # 160*160*320 --> 160*160*64
+
+     e5_d2 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(e5)  # 20*20*320 --> 160*160*320
+     e5_d2 = conv_block(e5_d2, cat_channels, n=1)  # 160*160*320 --> 160*160*64
+
+     d2 = k.layers.concatenate([e1_d2, e2_d2, d3_d2, d4_d2, e5_d2])
+     d2 = conv_block(d2, upsample_channels, n=1)  # 160*160*320 --> 160*160*320
+
+     """ d1 """
+     e1_d1 = conv_block(e1, cat_channels, n=1)  # 320*320*64 --> 320*320*64
+
+     d2_d1 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d2)  # 160*160*320 --> 320*320*320
+     d2_d1 = conv_block(d2_d1, cat_channels, n=1)  # 160*160*320 --> 160*160*64
+
+     d3_d1 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d3)  # 80*80*320 --> 320*320*320
+     d3_d1 = conv_block(d3_d1, cat_channels, n=1)  # 320*320*320 --> 320*320*64
+
+     d4_d1 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(d4)  # 40*40*320 --> 320*320*320
+     d4_d1 = conv_block(d4_d1, cat_channels, n=1)  # 320*320*320 --> 320*320*64
+
+     e5_d1 = k.layers.UpSampling2D(size=(16, 16), interpolation='bilinear')(e5)  # 20*20*320 --> 320*320*320
+     e5_d1 = conv_block(e5_d1, cat_channels, n=1)  # 320*320*320 --> 320*320*64
+
+     d1 = k.layers.concatenate([e1_d1, d2_d1, d3_d1, d4_d1, e5_d1, ])
+     d1 = conv_block(d1, upsample_channels, n=1)  # 320*320*320 --> 320*320*320
+
+     """ Deep Supervision Part"""
+     # last layer does not have batchnorm and relu
+     d1 = conv_block(d1, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
+     d2 = conv_block(d2, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
+     d3 = conv_block(d3, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
+     d4 = conv_block(d4, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
+     e5 = conv_block(e5, OUTPUT_CHANNELS, n=1, is_bn=False, is_relu=False)
+
+     # d1 = no need for upsampling
+     d2 = k.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(d2)
+     d3 = k.layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(d3)
+     d4 = k.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(d4)
+     e5 = k.layers.UpSampling2D(size=(16, 16), interpolation='bilinear')(e5)
+
+     """ Classification Guided Module. Part 2"""
+     d1 = dotProduct(d1, cls)
+     d2 = dotProduct(d2, cls)
+     d3 = dotProduct(d3, cls)
+     d4 = dotProduct(d4, cls)
+     e5 = dotProduct(e5, cls)
+
+     if OUTPUT_CHANNELS == 1:
+         d1 = k.activations.sigmoid(d1)
+         d2 = k.activations.sigmoid(d2)
+         d3 = k.activations.sigmoid(d3)
+         d4 = k.activations.sigmoid(d4)
+         e5 = k.activations.sigmoid(e5)
+     else:
+         d1 = k.activations.softmax(d1)
+         d2 = k.activations.softmax(d2)
+         d3 = k.activations.softmax(d3)
+         d4 = k.activations.softmax(d4)
+         e5 = k.activations.softmax(e5)
+
+     model = tf.keras.Model(inputs=input_layer, outputs=[d1, d2, d3, d4, e5], name='UNet_3Plus_DeepSup_CGM')
+     if(pretrained_weights):
+         model.load_weights(pretrained_weights)
+
+     return model
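Note: a minimal usage sketch for the UNet 3+ variants above, assuming the 256x256x1 input and single output channel that predict_unet.py passes in (the dummy inputs are illustrative only):

import numpy as np
from unet.unet_3plus import UNet_3Plus, UNet_3Plus_DeepSup

# Plain UNet 3+: a single sigmoid mask at input resolution.
model = UNet_3Plus(INPUT_SHAPE=[256, 256, 1], OUTPUT_CHANNELS=1)
mask = model.predict(np.random.rand(1, 256, 256, 1).astype("float32"))
print(mask.shape)  # (1, 256, 256, 1)

# Deep-supervision variant: one mask per decoder stage (d1..d4, e5),
# all upsampled to the input resolution.
ds_model = UNet_3Plus_DeepSup(INPUT_SHAPE=[256, 256, 1], OUTPUT_CHANNELS=1)
outputs = ds_model.predict(np.random.rand(1, 256, 256, 1).astype("float32"))
print([o.shape for o in outputs])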