bumble-bee committed on
Commit
3368fe8
1 Parent(s): 123a5ad

Add unet versions(v0, v1, v2, v3)

Browse files
app.py CHANGED
@@ -1,7 +1,216 @@
1
  import gradio as gr
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from gradio_client import Client
3
+ import numpy as np
4
+ #import torch
5
+ import requests
6
+ from PIL import Image
7
+ #from torchvision import transforms
8
+ from predict_unet import predict_model
9
 
 
 
10
 
11
+ # device = torch.device(
12
+ # "cuda"
13
+ # if torch.cuda.is_available()
14
+ # else "mps"
15
+ # if torch.backends.mps.is_available()
16
+ # else "cpu"
17
+ # )
18
+
19
# Heading rendered above the demo UI.
title = "<center><strong><font size='8'> Medical Image Segmentation with UNet </font></strong></center>"

# Sample image paths offered in every tab's example gallery.
examples = [
    ["examples/50494616.jpg"], ["examples/50494676.jpg"], ["examples/56399783.jpg"],
    ["examples/56399789.jpg"], ["examples/56399831.jpg"], ["examples/56399959.jpg"],
    ["examples/56400014.jpg"], ["examples/56400119.jpg"],
    ["examples/56481903.jpg"], ["examples/70749195.jpg"],
]
25
+
26
def run_unetv0(input):
    """Segment *input* with the regular UNet (v0); output is clipped to [0, 1]."""
    prediction = predict_model(input, "v0")
    return np.clip(prediction, 0, 1)
30
+
31
def run_unetv1(input):
    """Segment *input* with UNet3+ (v1); output is clipped to [0, 1]."""
    prediction = predict_model(input, "v1")
    return np.clip(prediction, 0, 1)
35
+
36
def run_unetv2(input):
    """Segment *input* with UNet3+ + deep supervision (v2); output clipped to [0, 1]."""
    prediction = predict_model(input, "v2")
    return np.clip(prediction, 0, 1)
40
+
41
def run_unetv3(input):
    """Segment *input* with UNet3+ + deep supervision + CGM (v3); output clipped to [0, 1]."""
    prediction = predict_model(input, "v3")
    return np.clip(prediction, 0, 1)
45
+
46
+
47
with gr.Blocks(title='UNet examples') as demo:
    # The four tabs were copy-pasted near-verbatim in the original; build
    # them from data instead.  Each entry is (tab label, inference fn).
    _variants = [
        ("Regular UNet (v0)", run_unetv0),
        ("UNet3+ (v1)", run_unetv1),
        ("UNet3+(v2) with deep supervision", run_unetv2),
        ("UNet3+(v3) with deep supervision and cgm", run_unetv3),
    ]
    # Collected per tab so the clear buttons can be wired up after the fact,
    # as in the original layout: (clear button, input image, output image).
    _clear_targets = []

    for _label, _run_fn in _variants:
        with gr.Tab(_label):
            # Input image side by side with the segmented result.
            with gr.Row(variant="panel"):
                with gr.Column(scale=1):
                    input_img = gr.Image(label="Input", type='numpy')
                with gr.Column(scale=1):
                    segm_img = gr.Image(label="Segmented Image")

            # Submit/clear controls plus the example gallery.
            with gr.Row():
                with gr.Column():
                    segment_btn = gr.Button("Run Segmentation", variant='primary')
                    clear_btn = gr.Button("Clear", variant="secondary")

                    gr.Markdown("Try some of the examples below")
                    gr.Examples(examples=examples,
                                inputs=[input_img],
                                outputs=segm_img,
                                fn=_run_fn,
                                cache_examples=False,
                                examples_per_page=5)

                # Placeholder second column (keeps the layout balanced).
                with gr.Column():
                    gr.Markdown("")

            segment_btn.click(_run_fn,
                              inputs=[input_img],
                              outputs=segm_img)
            _clear_targets.append((clear_btn, input_img, segm_img))

    def clear():
        """Reset a tab's input and output images."""
        return None, None

    for _btn, _inp, _out in _clear_targets:
        _btn.click(clear, outputs=[_inp, _out])


demo.queue()
demo.launch()
216
+
examples/50494616.jpg ADDED
examples/50494676.jpg ADDED
examples/56399783.jpg ADDED
examples/56399789.jpg ADDED
examples/56399831.jpg ADDED
examples/56399959.jpg ADDED
examples/56400014.jpg ADDED
examples/56400119.jpg ADDED
examples/56481903.jpg ADDED
examples/70749195.jpg ADDED
predict_unet.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import numpy as np
4
+ import skimage.io as skio
5
+ import skimage.transform as trans
6
+ from skimage.color import rgb2gray
7
+ from matplotlib import pyplot as plt
8
+ import sys
9
+
10
+ sys.path.append("/panfs/jay/groups/29/umii/mo000007/zooniverse/UNet")
11
+
12
+ from utils import *
13
+ from unet import unet
14
+ from unet_3plus import UNet_3Plus, UNet_3Plus_DeepSup, UNet_3Plus_DeepSup_CGM
15
+
16
+
17
def predict_model(input, unet_type):
    """Run one of the pretrained UNet variants on *input* and overlay its mask.

    Parameters
    ----------
    input : ndarray
        RGB image; it is resized to 256x256 before inference.
    unet_type : str
        'v0' = regular UNet, 'v1' = UNet3+, 'v2' = UNet3+ with deep
        supervision.  Any other value (including 'v3') falls back to
        UNet3+ with deep supervision and CGM, matching the original
        else-branch.

    Returns
    -------
    ndarray
        The resized input image with predicted-mask pixels painted
        ``[0, 0, 255]``.
    """
    model_path = "/home/umii/mo000007/zooniverse/UNet/trained_models"
    h, w = 256, 256
    input_shape = [h, w, 1]
    output_channels = 1

    # Resize to the model's input resolution; keep an RGB copy for the overlay.
    # NOTE(review): skimage's resize takes (rows, cols) = (h, w); passing
    # (w, h) is harmless here only because h == w — confirm if they diverge.
    img = trans.resize(input, (w, h))
    result_img = img.copy()
    img = rgb2gray(img).reshape(1, h, w, 1)

    # Version dispatch: checkpoint basename + model constructor.
    # v0: UNet, v1: UNet3+, v2: UNet3+ deep supervision, v3: + CGM.
    builders = {
        'v0': ("unetv0_sgd500_neptune",
               lambda f: unet(f)),
        'v1': ("unetv1_sgd500_neptune",
               lambda f: UNet_3Plus(input_shape, output_channels, f)),
        'v2': ("unetv2_sgd500_neptune",
               lambda f: UNet_3Plus_DeepSup(input_shape, output_channels, f)),
        'v3': ("unetv3_sgd500_neptune",
               lambda f: UNet_3Plus_DeepSup_CGM(input_shape, output_channels, f)),
    }
    model_name, build = builders.get(unet_type, builders['v3'])
    model_file = os.path.join(model_path, model_name + ".hdf5")
    model = build(model_file)

    results = model.predict(img)

    # The deep-supervision variants (v2/v3) return one output per decoder
    # stage; the first entry is the prediction used here.
    if unet_type in ('v2', 'v3'):
        pred = np.copy(results[0])
    else:
        pred = np.copy(results)

    # Binarise at 0.5 and strip the batch/channel axes.
    pred[pred >= 0.5] = 1
    pred[pred < 0.5] = 0
    output = np.array(pred[0][:, :, 0])

    # Paint mask pixels on the resized RGB copy.
    # NOTE(review): result_img comes from trans.resize and is float in [0, 1];
    # assigning 255 produces out-of-range values — presumably clipped by the
    # display layer, but worth confirming.
    seg_color = [0, 0, 255]
    result_img[output != 0] = seg_color

    return result_img
69
+