TornikeO commited on
Commit
b771b1c
1 Parent(s): 839dd48

Re-add models dir

Browse files
Files changed (3) hide show
  1. .gitignore +1 -2
  2. models/__init__.py +1 -0
  3. models/isnet.py +610 -0
.gitignore CHANGED
@@ -6,5 +6,4 @@ tmp/
6
  # *.png
7
  *.db
8
  __pycache__
9
- saved_models
10
- models
 
6
  # *.png
7
  *.db
8
  __pycache__
9
+ saved_models
 
models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from models.isnet import ISNetGTEncoder, ISNetDIS
models/isnet.py ADDED
@@ -0,0 +1,610 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models
4
+ import torch.nn.functional as F
5
+
6
+
7
+ bce_loss = nn.BCELoss(size_average=True)
8
+ def muti_loss_fusion(preds, target):
9
+ loss0 = 0.0
10
+ loss = 0.0
11
+
12
+ for i in range(0,len(preds)):
13
+ # print("i: ", i, preds[i].shape)
14
+ if(preds[i].shape[2]!=target.shape[2] or preds[i].shape[3]!=target.shape[3]):
15
+ # tmp_target = _upsample_like(target,preds[i])
16
+ tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
17
+ loss = loss + bce_loss(preds[i],tmp_target)
18
+ else:
19
+ loss = loss + bce_loss(preds[i],target)
20
+ if(i==0):
21
+ loss0 = loss
22
+ return loss0, loss
23
+
24
+ fea_loss = nn.MSELoss(size_average=True)
25
+ kl_loss = nn.KLDivLoss(size_average=True)
26
+ l1_loss = nn.L1Loss(size_average=True)
27
+ smooth_l1_loss = nn.SmoothL1Loss(size_average=True)
28
+ def muti_loss_fusion_kl(preds, target, dfs, fs, mode='MSE'):
29
+ loss0 = 0.0
30
+ loss = 0.0
31
+
32
+ for i in range(0,len(preds)):
33
+ # print("i: ", i, preds[i].shape)
34
+ if(preds[i].shape[2]!=target.shape[2] or preds[i].shape[3]!=target.shape[3]):
35
+ # tmp_target = _upsample_like(target,preds[i])
36
+ tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
37
+ loss = loss + bce_loss(preds[i],tmp_target)
38
+ else:
39
+ loss = loss + bce_loss(preds[i],target)
40
+ if(i==0):
41
+ loss0 = loss
42
+
43
+ for i in range(0,len(dfs)):
44
+ if(mode=='MSE'):
45
+ loss = loss + fea_loss(dfs[i],fs[i]) ### add the mse loss of features as additional constraints
46
+ # print("fea_loss: ", fea_loss(dfs[i],fs[i]).item())
47
+ elif(mode=='KL'):
48
+ loss = loss + kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1))
49
+ # print("kl_loss: ", kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1)).item())
50
+ elif(mode=='MAE'):
51
+ loss = loss + l1_loss(dfs[i],fs[i])
52
+ # print("ls_loss: ", l1_loss(dfs[i],fs[i]))
53
+ elif(mode=='SmoothL1'):
54
+ loss = loss + smooth_l1_loss(dfs[i],fs[i])
55
+ # print("SmoothL1: ", smooth_l1_loss(dfs[i],fs[i]).item())
56
+
57
+ return loss0, loss
58
+
59
+ class REBNCONV(nn.Module):
60
+ def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
61
+ super(REBNCONV,self).__init__()
62
+
63
+ self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride)
64
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
65
+ self.relu_s1 = nn.ReLU(inplace=True)
66
+
67
+ def forward(self,x):
68
+
69
+ hx = x
70
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
71
+
72
+ return xout
73
+
74
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
75
+ def _upsample_like(src,tar):
76
+
77
+ src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
78
+
79
+ return src
80
+
81
+
82
+ ### RSU-7 ###
83
+ class RSU7(nn.Module):
84
+
85
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
86
+ super(RSU7,self).__init__()
87
+
88
+ self.in_ch = in_ch
89
+ self.mid_ch = mid_ch
90
+ self.out_ch = out_ch
91
+
92
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2
93
+
94
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
95
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
96
+
97
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
98
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
99
+
100
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
101
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
102
+
103
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
104
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
105
+
106
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
107
+ self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
108
+
109
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
110
+
111
+ self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
112
+
113
+ self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
114
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
115
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
116
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
117
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
118
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
119
+
120
+ def forward(self,x):
121
+ b, c, h, w = x.shape
122
+
123
+ hx = x
124
+ hxin = self.rebnconvin(hx)
125
+
126
+ hx1 = self.rebnconv1(hxin)
127
+ hx = self.pool1(hx1)
128
+
129
+ hx2 = self.rebnconv2(hx)
130
+ hx = self.pool2(hx2)
131
+
132
+ hx3 = self.rebnconv3(hx)
133
+ hx = self.pool3(hx3)
134
+
135
+ hx4 = self.rebnconv4(hx)
136
+ hx = self.pool4(hx4)
137
+
138
+ hx5 = self.rebnconv5(hx)
139
+ hx = self.pool5(hx5)
140
+
141
+ hx6 = self.rebnconv6(hx)
142
+
143
+ hx7 = self.rebnconv7(hx6)
144
+
145
+ hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
146
+ hx6dup = _upsample_like(hx6d,hx5)
147
+
148
+ hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
149
+ hx5dup = _upsample_like(hx5d,hx4)
150
+
151
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
152
+ hx4dup = _upsample_like(hx4d,hx3)
153
+
154
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
155
+ hx3dup = _upsample_like(hx3d,hx2)
156
+
157
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
158
+ hx2dup = _upsample_like(hx2d,hx1)
159
+
160
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
161
+
162
+ return hx1d + hxin
163
+
164
+
165
+ ### RSU-6 ###
166
+ class RSU6(nn.Module):
167
+
168
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
169
+ super(RSU6,self).__init__()
170
+
171
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
172
+
173
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
174
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
175
+
176
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
177
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
178
+
179
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
180
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
181
+
182
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
183
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
184
+
185
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
186
+
187
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
188
+
189
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
190
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
191
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
192
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
193
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
194
+
195
+ def forward(self,x):
196
+
197
+ hx = x
198
+
199
+ hxin = self.rebnconvin(hx)
200
+
201
+ hx1 = self.rebnconv1(hxin)
202
+ hx = self.pool1(hx1)
203
+
204
+ hx2 = self.rebnconv2(hx)
205
+ hx = self.pool2(hx2)
206
+
207
+ hx3 = self.rebnconv3(hx)
208
+ hx = self.pool3(hx3)
209
+
210
+ hx4 = self.rebnconv4(hx)
211
+ hx = self.pool4(hx4)
212
+
213
+ hx5 = self.rebnconv5(hx)
214
+
215
+ hx6 = self.rebnconv6(hx5)
216
+
217
+
218
+ hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
219
+ hx5dup = _upsample_like(hx5d,hx4)
220
+
221
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
222
+ hx4dup = _upsample_like(hx4d,hx3)
223
+
224
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
225
+ hx3dup = _upsample_like(hx3d,hx2)
226
+
227
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
228
+ hx2dup = _upsample_like(hx2d,hx1)
229
+
230
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
231
+
232
+ return hx1d + hxin
233
+
234
+ ### RSU-5 ###
235
+ class RSU5(nn.Module):
236
+
237
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
238
+ super(RSU5,self).__init__()
239
+
240
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
241
+
242
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
243
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
244
+
245
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
246
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
247
+
248
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
249
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
250
+
251
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
252
+
253
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
254
+
255
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
256
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
257
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
258
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
259
+
260
+ def forward(self,x):
261
+
262
+ hx = x
263
+
264
+ hxin = self.rebnconvin(hx)
265
+
266
+ hx1 = self.rebnconv1(hxin)
267
+ hx = self.pool1(hx1)
268
+
269
+ hx2 = self.rebnconv2(hx)
270
+ hx = self.pool2(hx2)
271
+
272
+ hx3 = self.rebnconv3(hx)
273
+ hx = self.pool3(hx3)
274
+
275
+ hx4 = self.rebnconv4(hx)
276
+
277
+ hx5 = self.rebnconv5(hx4)
278
+
279
+ hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
280
+ hx4dup = _upsample_like(hx4d,hx3)
281
+
282
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
283
+ hx3dup = _upsample_like(hx3d,hx2)
284
+
285
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
286
+ hx2dup = _upsample_like(hx2d,hx1)
287
+
288
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
289
+
290
+ return hx1d + hxin
291
+
292
+ ### RSU-4 ###
293
+ class RSU4(nn.Module):
294
+
295
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
296
+ super(RSU4,self).__init__()
297
+
298
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
299
+
300
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
301
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
302
+
303
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
304
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
305
+
306
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
307
+
308
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
309
+
310
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
311
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
312
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
313
+
314
+ def forward(self,x):
315
+
316
+ hx = x
317
+
318
+ hxin = self.rebnconvin(hx)
319
+
320
+ hx1 = self.rebnconv1(hxin)
321
+ hx = self.pool1(hx1)
322
+
323
+ hx2 = self.rebnconv2(hx)
324
+ hx = self.pool2(hx2)
325
+
326
+ hx3 = self.rebnconv3(hx)
327
+
328
+ hx4 = self.rebnconv4(hx3)
329
+
330
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
331
+ hx3dup = _upsample_like(hx3d,hx2)
332
+
333
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
334
+ hx2dup = _upsample_like(hx2d,hx1)
335
+
336
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
337
+
338
+ return hx1d + hxin
339
+
340
+ ### RSU-4F ###
341
+ class RSU4F(nn.Module):
342
+
343
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
344
+ super(RSU4F,self).__init__()
345
+
346
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
347
+
348
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
349
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
350
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
351
+
352
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
353
+
354
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
355
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
356
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
357
+
358
+ def forward(self,x):
359
+
360
+ hx = x
361
+
362
+ hxin = self.rebnconvin(hx)
363
+
364
+ hx1 = self.rebnconv1(hxin)
365
+ hx2 = self.rebnconv2(hx1)
366
+ hx3 = self.rebnconv3(hx2)
367
+
368
+ hx4 = self.rebnconv4(hx3)
369
+
370
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
371
+ hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
372
+ hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
373
+
374
+ return hx1d + hxin
375
+
376
+
377
+ class myrebnconv(nn.Module):
378
+ def __init__(self, in_ch=3,
379
+ out_ch=1,
380
+ kernel_size=3,
381
+ stride=1,
382
+ padding=1,
383
+ dilation=1,
384
+ groups=1):
385
+ super(myrebnconv,self).__init__()
386
+
387
+ self.conv = nn.Conv2d(in_ch,
388
+ out_ch,
389
+ kernel_size=kernel_size,
390
+ stride=stride,
391
+ padding=padding,
392
+ dilation=dilation,
393
+ groups=groups)
394
+ self.bn = nn.BatchNorm2d(out_ch)
395
+ self.rl = nn.ReLU(inplace=True)
396
+
397
+ def forward(self,x):
398
+ return self.rl(self.bn(self.conv(x)))
399
+
400
+
401
+ class ISNetGTEncoder(nn.Module):
402
+
403
+ def __init__(self,in_ch=1,out_ch=1):
404
+ super(ISNetGTEncoder,self).__init__()
405
+
406
+ self.conv_in = myrebnconv(in_ch,16,3,stride=2,padding=1) # nn.Conv2d(in_ch,64,3,stride=2,padding=1)
407
+
408
+ self.stage1 = RSU7(16,16,64)
409
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
410
+
411
+ self.stage2 = RSU6(64,16,64)
412
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
413
+
414
+ self.stage3 = RSU5(64,32,128)
415
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
416
+
417
+ self.stage4 = RSU4(128,32,256)
418
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
419
+
420
+ self.stage5 = RSU4F(256,64,512)
421
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
422
+
423
+ self.stage6 = RSU4F(512,64,512)
424
+
425
+
426
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
427
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
428
+ self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
429
+ self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
430
+ self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
431
+ self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
432
+
433
+ def compute_loss(self, preds, targets):
434
+
435
+ return muti_loss_fusion(preds,targets)
436
+
437
+ def forward(self,x):
438
+
439
+ hx = x
440
+
441
+ hxin = self.conv_in(hx)
442
+ # hx = self.pool_in(hxin)
443
+
444
+ #stage 1
445
+ hx1 = self.stage1(hxin)
446
+ hx = self.pool12(hx1)
447
+
448
+ #stage 2
449
+ hx2 = self.stage2(hx)
450
+ hx = self.pool23(hx2)
451
+
452
+ #stage 3
453
+ hx3 = self.stage3(hx)
454
+ hx = self.pool34(hx3)
455
+
456
+ #stage 4
457
+ hx4 = self.stage4(hx)
458
+ hx = self.pool45(hx4)
459
+
460
+ #stage 5
461
+ hx5 = self.stage5(hx)
462
+ hx = self.pool56(hx5)
463
+
464
+ #stage 6
465
+ hx6 = self.stage6(hx)
466
+
467
+
468
+ #side output
469
+ d1 = self.side1(hx1)
470
+ d1 = _upsample_like(d1,x)
471
+
472
+ d2 = self.side2(hx2)
473
+ d2 = _upsample_like(d2,x)
474
+
475
+ d3 = self.side3(hx3)
476
+ d3 = _upsample_like(d3,x)
477
+
478
+ d4 = self.side4(hx4)
479
+ d4 = _upsample_like(d4,x)
480
+
481
+ d5 = self.side5(hx5)
482
+ d5 = _upsample_like(d5,x)
483
+
484
+ d6 = self.side6(hx6)
485
+ d6 = _upsample_like(d6,x)
486
+
487
+ # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
488
+
489
+ return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)], [hx1,hx2,hx3,hx4,hx5,hx6]
490
+
491
+ class ISNetDIS(nn.Module):
492
+
493
+ def __init__(self,in_ch=3,out_ch=1):
494
+ super(ISNetDIS,self).__init__()
495
+
496
+ self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
497
+ self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
498
+
499
+ self.stage1 = RSU7(64,32,64)
500
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
501
+
502
+ self.stage2 = RSU6(64,32,128)
503
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
504
+
505
+ self.stage3 = RSU5(128,64,256)
506
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
507
+
508
+ self.stage4 = RSU4(256,128,512)
509
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
510
+
511
+ self.stage5 = RSU4F(512,256,512)
512
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
513
+
514
+ self.stage6 = RSU4F(512,256,512)
515
+
516
+ # decoder
517
+ self.stage5d = RSU4F(1024,256,512)
518
+ self.stage4d = RSU4(1024,128,256)
519
+ self.stage3d = RSU5(512,64,128)
520
+ self.stage2d = RSU6(256,32,64)
521
+ self.stage1d = RSU7(128,16,64)
522
+
523
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
524
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
525
+ self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
526
+ self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
527
+ self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
528
+ self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
529
+
530
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
531
+
532
+ def compute_loss_kl(self, preds, targets, dfs, fs, mode='MSE'):
533
+
534
+ # return muti_loss_fusion(preds,targets)
535
+ return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)
536
+
537
+ def compute_loss(self, preds, targets):
538
+
539
+ # return muti_loss_fusion(preds,targets)
540
+ return muti_loss_fusion(preds, targets)
541
+
542
+ def forward(self,x):
543
+
544
+ hx = x
545
+
546
+ hxin = self.conv_in(hx)
547
+ #hx = self.pool_in(hxin)
548
+
549
+ #stage 1
550
+ hx1 = self.stage1(hxin)
551
+ hx = self.pool12(hx1)
552
+
553
+ #stage 2
554
+ hx2 = self.stage2(hx)
555
+ hx = self.pool23(hx2)
556
+
557
+ #stage 3
558
+ hx3 = self.stage3(hx)
559
+ hx = self.pool34(hx3)
560
+
561
+ #stage 4
562
+ hx4 = self.stage4(hx)
563
+ hx = self.pool45(hx4)
564
+
565
+ #stage 5
566
+ hx5 = self.stage5(hx)
567
+ hx = self.pool56(hx5)
568
+
569
+ #stage 6
570
+ hx6 = self.stage6(hx)
571
+ hx6up = _upsample_like(hx6,hx5)
572
+
573
+ #-------------------- decoder --------------------
574
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
575
+ hx5dup = _upsample_like(hx5d,hx4)
576
+
577
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
578
+ hx4dup = _upsample_like(hx4d,hx3)
579
+
580
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
581
+ hx3dup = _upsample_like(hx3d,hx2)
582
+
583
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
584
+ hx2dup = _upsample_like(hx2d,hx1)
585
+
586
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
587
+
588
+
589
+ #side output
590
+ d1 = self.side1(hx1d)
591
+ d1 = _upsample_like(d1,x)
592
+
593
+ d2 = self.side2(hx2d)
594
+ d2 = _upsample_like(d2,x)
595
+
596
+ d3 = self.side3(hx3d)
597
+ d3 = _upsample_like(d3,x)
598
+
599
+ d4 = self.side4(hx4d)
600
+ d4 = _upsample_like(d4,x)
601
+
602
+ d5 = self.side5(hx5d)
603
+ d5 = _upsample_like(d5,x)
604
+
605
+ d6 = self.side6(hx6)
606
+ d6 = _upsample_like(d6,x)
607
+
608
+ # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
609
+
610
+ return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6]