Feature Extraction
Transformers
PyTorch
bbsnet
custom_code
thinh-huynh-re committed on
Commit e674a89 · 1 Parent(s): b4d5d67

Upload model

Files changed (3)
  1. BBSNet_model.py +458 -0
  2. ResNet.py +156 -0
  3. modeling_bbsnet.py +1 -2
BBSNet_model.py ADDED
@@ -0,0 +1,458 @@
+import torch
+import torch.nn as nn
+import torchvision
+
+from .ResNet import ResNet50
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(
+        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
+    )
+
+
+class TransBasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, upsample=None, **kwargs):
+        super(TransBasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, inplanes)
+        self.bn1 = nn.BatchNorm2d(inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        if upsample is not None and stride != 1:
+            self.conv2 = nn.ConvTranspose2d(
+                inplanes,
+                planes,
+                kernel_size=3,
+                stride=stride,
+                padding=1,
+                output_padding=1,
+                bias=False,
+            )
+        else:
+            self.conv2 = conv3x3(inplanes, planes, stride)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.upsample = upsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.upsample is not None:
+            residual = self.upsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class ChannelAttention(nn.Module):
+    def __init__(self, in_planes, ratio=16):
+        super(ChannelAttention, self).__init__()
+
+        self.max_pool = nn.AdaptiveMaxPool2d(1)
+
+        # use the ratio argument instead of a hard-coded 16
+        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
+        self.relu1 = nn.ReLU()
+        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
+
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        # max-pooled channel descriptor only (no average-pooling branch)
+        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
+        return self.sigmoid(max_out)
+
+
+class SpatialAttention(nn.Module):
+    def __init__(self, kernel_size=7):
+        super(SpatialAttention, self).__init__()
+
+        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
+        padding = 3 if kernel_size == 7 else 1
+
+        self.conv1 = nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        # channel-wise max map, then a single conv and a sigmoid gate
+        max_out, _ = torch.max(x, dim=1, keepdim=True)
+        return self.sigmoid(self.conv1(max_out))
+
+
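These two attention blocks are applied back to back in the depth-enhanced merges of BBSNet.forward further down. A minimal sketch of that gating pattern (the feature tensors here are random placeholders):

    import torch

    ca = ChannelAttention(256)
    sa = SpatialAttention()
    rgb_feat = torch.randn(1, 256, 64, 64)    # placeholder RGB feature
    depth_feat = torch.randn(1, 256, 64, 64)  # placeholder depth feature

    temp = depth_feat.mul(ca(depth_feat))  # channel gate on the depth stream
    temp = temp.mul(sa(temp))              # then a spatial gate
    fused = rgb_feat + temp                # residual fusion into the RGB stream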
+class BasicConv2d(nn.Module):
+    def __init__(
+        self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1
+    ):
+        super(BasicConv2d, self).__init__()
+        self.conv = nn.Conv2d(
+            in_planes,
+            out_planes,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            bias=False,
+        )
+        self.bn = nn.BatchNorm2d(out_planes)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        # note: self.relu is defined but not applied; forward is conv + BN only
+        x = self.conv(x)
+        x = self.bn(x)
+        return x
+
+
+# Global Contextual module
+class GCM(nn.Module):
+    def __init__(self, in_channel, out_channel):
+        super(GCM, self).__init__()
+        self.relu = nn.ReLU(True)
+        self.branch0 = nn.Sequential(
+            BasicConv2d(in_channel, out_channel, 1),
+        )
+        self.branch1 = nn.Sequential(
+            BasicConv2d(in_channel, out_channel, 1),
+            BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)),
+            BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)),
+            BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3),
+        )
+        self.branch2 = nn.Sequential(
+            BasicConv2d(in_channel, out_channel, 1),
+            BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)),
+            BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)),
+            BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5),
+        )
+        self.branch3 = nn.Sequential(
+            BasicConv2d(in_channel, out_channel, 1),
+            BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)),
+            BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)),
+            BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7),
+        )
+        self.conv_cat = BasicConv2d(4 * out_channel, out_channel, 3, padding=1)
+        self.conv_res = BasicConv2d(in_channel, out_channel, 1)
+
+    def forward(self, x):
+        x0 = self.branch0(x)
+        x1 = self.branch1(x)
+        x2 = self.branch2(x)
+        x3 = self.branch3(x)
+
+        x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1))
+
+        x = self.relu(x_cat + self.conv_res(x))
+        return x
+
+
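Every GCM branch is padded so that it preserves the spatial size, which is what lets the four outputs be concatenated and fused with the 1x1 residual path. A quick hedged shape check (channel sizes match the decoder-1 usage below):

    import torch

    gcm = GCM(512, 32)
    feat = torch.randn(1, 512, 32, 32)
    print(gcm(feat).shape)  # torch.Size([1, 32, 32, 32])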
+# aggregation of the high-level (teacher) features
+class aggregation_init(nn.Module):
+    def __init__(self, channel):
+        super(aggregation_init, self).__init__()
+        self.relu = nn.ReLU(True)
+
+        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
+        self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1)
+        self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1)
+        self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1)
+        self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1)
+        self.conv_upsample5 = BasicConv2d(2 * channel, 2 * channel, 3, padding=1)
+
+        self.conv_concat2 = BasicConv2d(2 * channel, 2 * channel, 3, padding=1)
+        self.conv_concat3 = BasicConv2d(3 * channel, 3 * channel, 3, padding=1)
+        self.conv4 = BasicConv2d(3 * channel, 3 * channel, 3, padding=1)
+        self.conv5 = nn.Conv2d(3 * channel, 1, 1)
+
+    def forward(self, x1, x2, x3):
+        x1_1 = x1
+        x2_1 = self.conv_upsample1(self.upsample(x1)) * x2
+        x3_1 = (
+            self.conv_upsample2(self.upsample(self.upsample(x1)))
+            * self.conv_upsample3(self.upsample(x2))
+            * x3
+        )
+
+        x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1)
+        x2_2 = self.conv_concat2(x2_2)
+
+        x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1)
+        x3_2 = self.conv_concat3(x3_2)
+
+        x = self.conv4(x3_2)
+        x = self.conv5(x)
+
+        return x
+
+
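aggregation_init expects its three inputs ordered coarse to fine, each one octave apart, and emits a single-channel map at the resolution of the finest input. A hedged shape check matching the decoder-1 call in BBSNet.forward:

    import torch

    agg1 = aggregation_init(32)
    x4 = torch.randn(1, 32, 8, 8)    # coarsest (from rfb4_1)
    x3 = torch.randn(1, 32, 16, 16)  # middle (from rfb3_1)
    x2 = torch.randn(1, 32, 32, 32)  # finest (from rfb2_1)
    print(agg1(x4, x3, x2).shape)  # torch.Size([1, 1, 32, 32]): initial saliency map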
+# aggregation of the low-level (student) features
+class aggregation_final(nn.Module):
+    def __init__(self, channel):
+        super(aggregation_final, self).__init__()
+        self.relu = nn.ReLU(True)
+
+        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
+        self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1)
+        self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1)
+        self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1)
+        self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1)
+        self.conv_upsample5 = BasicConv2d(2 * channel, 2 * channel, 3, padding=1)
+
+        self.conv_concat2 = BasicConv2d(2 * channel, 2 * channel, 3, padding=1)
+        self.conv_concat3 = BasicConv2d(3 * channel, 3 * channel, 3, padding=1)
+
+    def forward(self, x1, x2, x3):
+        x1_1 = x1
+        x2_1 = self.conv_upsample1(self.upsample(x1)) * x2
+        x3_1 = self.conv_upsample2(self.upsample(x1)) * self.conv_upsample3(x2) * x3
+
+        x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1)
+        x2_2 = self.conv_concat2(x2_2)
+
+        x3_2 = torch.cat((x3_1, self.conv_upsample5(x2_2)), 1)
+        x3_2 = self.conv_concat3(x3_2)
+
+        return x3_2
+
+
+# Refinement flow
+class Refine(nn.Module):
+    def __init__(self):
+        super(Refine, self).__init__()
+        self.upsample2 = nn.Upsample(
+            scale_factor=2, mode="bilinear", align_corners=True
+        )
+
+    def forward(self, attention, x1, x2, x3):
+        # Note: there is an error in the manuscript. The paper depicts the
+        # refinement strategy as "f' = f * S1"; it should be "f' = f + f * S1".
+        x1 = x1 + torch.mul(x1, self.upsample2(attention))
+        x2 = x2 + torch.mul(x2, self.upsample2(attention))
+        x3 = x3 + torch.mul(x3, attention)
+
+        return x1, x2, x3
+
+
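Refine relies on broadcasting: the single-channel attention map multiplies every channel of each feature tensor, and the first two inputs are assumed to sit one octave above the attention map while the third shares its resolution. A small hedged illustration with placeholder tensors:

    import torch

    refine = Refine()
    att = torch.sigmoid(torch.randn(1, 1, 32, 32))  # initial map at 1/8 scale
    f0 = torch.randn(1, 64, 64, 64)
    f1 = torch.randn(1, 256, 64, 64)
    f2 = torch.randn(1, 512, 32, 32)
    r0, r1, r2 = refine(att, f0, f1, f2)  # shapes preserved: f' = f + f * S1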
+# BBSNet
+class BBSNet(nn.Module):
+    def __init__(self, channel=32):
+        super(BBSNet, self).__init__()
+
+        # Backbone models (two streams: RGB and depth)
+        self.resnet = ResNet50("rgb")
+        self.resnet_depth = ResNet50("rgbd")
+
+        # Decoder 1
+        self.rfb2_1 = GCM(512, channel)
+        self.rfb3_1 = GCM(1024, channel)
+        self.rfb4_1 = GCM(2048, channel)
+        self.agg1 = aggregation_init(channel)
+
+        # Decoder 2
+        self.rfb0_2 = GCM(64, channel)
+        self.rfb1_2 = GCM(256, channel)
+        self.rfb5_2 = GCM(512, channel)
+        self.agg2 = aggregation_final(channel)
+
+        # upsample functions
+        self.upsample = nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True)
+        self.upsample4 = nn.Upsample(
+            scale_factor=4, mode="bilinear", align_corners=True
+        )
+        self.upsample2 = nn.Upsample(
+            scale_factor=2, mode="bilinear", align_corners=True
+        )
+
+        # Refinement flow
+        self.HA = Refine()
+
+        # Components of the DEM module
+        self.atten_depth_channel_0 = ChannelAttention(64)
+        self.atten_depth_channel_1 = ChannelAttention(256)
+        self.atten_depth_channel_2 = ChannelAttention(512)
+        self.atten_depth_channel_3_1 = ChannelAttention(1024)
+        self.atten_depth_channel_4_1 = ChannelAttention(2048)
+
+        self.atten_depth_spatial_0 = SpatialAttention()
+        self.atten_depth_spatial_1 = SpatialAttention()
+        self.atten_depth_spatial_2 = SpatialAttention()
+        self.atten_depth_spatial_3_1 = SpatialAttention()
+        self.atten_depth_spatial_4_1 = SpatialAttention()
+
+        # Components of the PTM module
+        self.inplanes = 32 * 2
+        self.deconv1 = self._make_transpose(TransBasicBlock, 32 * 2, 3, stride=2)
+        self.inplanes = 32
+        self.deconv2 = self._make_transpose(TransBasicBlock, 32, 3, stride=2)
+        self.agant1 = self._make_agant_layer(32 * 3, 32 * 2)
+        self.agant2 = self._make_agant_layer(32 * 2, 32)
+        self.out0_conv = nn.Conv2d(32 * 3, 1, kernel_size=1, stride=1, bias=True)
+        self.out1_conv = nn.Conv2d(32 * 2, 1, kernel_size=1, stride=1, bias=True)
+        self.out2_conv = nn.Conv2d(32 * 1, 1, kernel_size=1, stride=1, bias=True)
+
+        # nn.Module is constructed in training mode, so this runs at __init__
+        # time and loads ImageNet-pretrained ResNet-50 weights into both backbones
+        if self.training:
+            self.initialize_weights()
+
+    def forward(self, x, x_depth):
+        x = self.resnet.conv1(x)
+        x = self.resnet.bn1(x)
+        x = self.resnet.relu(x)
+        x = self.resnet.maxpool(x)
+
+        x_depth = self.resnet_depth.conv1(x_depth)
+        x_depth = self.resnet_depth.bn1(x_depth)
+        x_depth = self.resnet_depth.relu(x_depth)
+        x_depth = self.resnet_depth.maxpool(x_depth)
+
+        # layer0 merge
+        temp = x_depth.mul(self.atten_depth_channel_0(x_depth))
+        temp = temp.mul(self.atten_depth_spatial_0(temp))
+        x = x + temp
+        # layer0 merge end
+
+        x1 = self.resnet.layer1(x)  # 256 x 64 x 64
+        x1_depth = self.resnet_depth.layer1(x_depth)
+
+        # layer1 merge
+        temp = x1_depth.mul(self.atten_depth_channel_1(x1_depth))
+        temp = temp.mul(self.atten_depth_spatial_1(temp))
+        x1 = x1 + temp
+        # layer1 merge end
+
+        x2 = self.resnet.layer2(x1)  # 512 x 32 x 32
+        x2_depth = self.resnet_depth.layer2(x1_depth)
+
+        # layer2 merge
+        temp = x2_depth.mul(self.atten_depth_channel_2(x2_depth))
+        temp = temp.mul(self.atten_depth_spatial_2(temp))
+        x2 = x2 + temp
+        # layer2 merge end
+
+        x2_1 = x2
+
+        x3_1 = self.resnet.layer3_1(x2_1)  # 1024 x 16 x 16
+        x3_1_depth = self.resnet_depth.layer3_1(x2_depth)
+
+        # layer3_1 merge
+        temp = x3_1_depth.mul(self.atten_depth_channel_3_1(x3_1_depth))
+        temp = temp.mul(self.atten_depth_spatial_3_1(temp))
+        x3_1 = x3_1 + temp
+        # layer3_1 merge end
+
+        x4_1 = self.resnet.layer4_1(x3_1)  # 2048 x 8 x 8
+        x4_1_depth = self.resnet_depth.layer4_1(x3_1_depth)
+
+        # layer4_1 merge
+        temp = x4_1_depth.mul(self.atten_depth_channel_4_1(x4_1_depth))
+        temp = temp.mul(self.atten_depth_spatial_4_1(temp))
+        x4_1 = x4_1 + temp
+        # layer4_1 merge end
+
+        # produce the initial saliency map with decoder 1
+        x2_1 = self.rfb2_1(x2_1)
+        x3_1 = self.rfb3_1(x3_1)
+        x4_1 = self.rfb4_1(x4_1)
+        attention_map = self.agg1(x4_1, x3_1, x2_1)
+
+        # refine the low-layer features with the initial map
+        x, x1, x5 = self.HA(attention_map.sigmoid(), x, x1, x2)
+
+        # produce the final saliency map with decoder 2
+        x0_2 = self.rfb0_2(x)
+        x1_2 = self.rfb1_2(x1)
+        x5_2 = self.rfb5_2(x5)
+        y = self.agg2(x5_2, x1_2, x0_2)  # *4
+
+        # PTM module
+        y = self.agant1(y)
+        y = self.deconv1(y)
+        y = self.agant2(y)
+        y = self.deconv2(y)
+        y = self.out2_conv(y)
+
+        return self.upsample(attention_map), y
+
+    def _make_agant_layer(self, inplanes, planes):
+        layers = nn.Sequential(
+            nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False),
+            nn.BatchNorm2d(planes),
+            nn.ReLU(inplace=True),
+        )
+        return layers
+
+    def _make_transpose(self, block, planes, blocks, stride=1):
+        upsample = None
+        if stride != 1:
+            upsample = nn.Sequential(
+                nn.ConvTranspose2d(
+                    self.inplanes,
+                    planes,
+                    kernel_size=2,
+                    stride=stride,
+                    padding=0,
+                    bias=False,
+                ),
+                nn.BatchNorm2d(planes),
+            )
+        elif self.inplanes != planes:
+            upsample = nn.Sequential(
+                nn.Conv2d(
+                    self.inplanes, planes, kernel_size=1, stride=stride, bias=False
+                ),
+                nn.BatchNorm2d(planes),
+            )
+
+        layers = []
+
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, self.inplanes))
+
+        layers.append(block(self.inplanes, planes, stride, upsample))
+        self.inplanes = planes
+
+        return nn.Sequential(*layers)
+
+    # initialize the weights
+    def initialize_weights(self):
+        # torchvision's `pretrained=True` flag is deprecated in newer releases
+        # but kept here as written
+        res50 = torchvision.models.resnet50(pretrained=True)
+        pretrained_dict = res50.state_dict()
+        all_params = {}
+        for k, v in self.resnet.state_dict().items():
+            if k in pretrained_dict.keys():
+                v = pretrained_dict[k]
+                all_params[k] = v
+            elif "_1" in k:
+                # e.g. "layer3_1.0.conv1.weight" -> "layer3.0.conv1.weight"
+                name = k.split("_1")[0] + k.split("_1")[1]
+                v = pretrained_dict[name]
+                all_params[k] = v
+            elif "_2" in k:
+                name = k.split("_2")[0] + k.split("_2")[1]
+                v = pretrained_dict[name]
+                all_params[k] = v
+        assert len(all_params.keys()) == len(self.resnet.state_dict().keys())
+        self.resnet.load_state_dict(all_params)
+
+        all_params = {}
+        for k, v in self.resnet_depth.state_dict().items():
+            if k == "conv1.weight":
+                # the depth stem takes 1-channel input, so it cannot reuse the
+                # 3-channel pretrained kernel and is re-initialized instead
+                all_params[k] = torch.nn.init.normal_(v, mean=0, std=1)
+            elif k in pretrained_dict.keys():
+                v = pretrained_dict[k]
+                all_params[k] = v
+            elif "_1" in k:
+                name = k.split("_1")[0] + k.split("_1")[1]
+                v = pretrained_dict[name]
+                all_params[k] = v
+            elif "_2" in k:
+                name = k.split("_2")[0] + k.split("_2")[1]
+                v = pretrained_dict[name]
+                all_params[k] = v
+        assert len(all_params.keys()) == len(self.resnet_depth.state_dict().keys())
+        self.resnet_depth.load_state_dict(all_params)
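A hedged end-to-end smoke test of the model above. The 256x256 input size is an assumption inferred from the shape comments in forward (e.g. 256 x 64 x 64 after layer1); note that constructing BBSNet in training mode downloads ImageNet ResNet-50 weights:

    import torch
    from BBSNet_model import BBSNet  # inside this repo the import is relative

    model = BBSNet(channel=32).eval()
    rgb = torch.randn(1, 3, 256, 256)
    depth = torch.randn(1, 1, 256, 256)  # the depth stream is single-channel
    with torch.no_grad():
        initial_map, final_map = model(rgb, depth)
    print(initial_map.shape, final_map.shape)  # both torch.Size([1, 1, 256, 256])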
ResNet.py ADDED
@@ -0,0 +1,156 @@
+import torch.nn as nn
+import math
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(
+        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
+    )
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(
+            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
+        )
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * 4)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet50(nn.Module):
+    def __init__(self, mode="rgb"):
+        self.inplanes = 64
+        super(ResNet50, self).__init__()
+        if mode == "rgb":
+            self.conv1 = nn.Conv2d(
+                3, 64, kernel_size=7, stride=2, padding=3, bias=False
+            )
+        elif mode == "rgbd":
+            # single-channel stem for the depth stream
+            self.conv1 = nn.Conv2d(
+                1, 64, kernel_size=7, stride=2, padding=3, bias=False
+            )
+        elif mode == "share":
+            self.conv1 = nn.Conv2d(
+                3, 64, kernel_size=7, stride=2, padding=3, bias=False
+            )
+            self.conv1_d = nn.Conv2d(
+                1, 64, kernel_size=7, stride=2, padding=3, bias=False
+            )
+        else:
+            raise ValueError("mode must be 'rgb', 'rgbd' or 'share', got %r" % mode)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(Bottleneck, 64, 3)
+        self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)
+        self.layer3_1 = self._make_layer(Bottleneck, 256, 6, stride=2)
+        self.layer4_1 = self._make_layer(Bottleneck, 512, 3, stride=2)
+
+        self.inplanes = 512
+
+        # He-style initialization for convs, constant for BN
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2.0 / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(
+                    self.inplanes,
+                    planes * block.expansion,
+                    kernel_size=1,
+                    stride=stride,
+                    bias=False,
+                ),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        # note: BBSNet calls the stem and stages directly and never uses this
+        # forward, which returns the final feature map twice
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x1 = self.layer3_1(x)
+        x1 = self.layer4_1(x1)
+
+        return x1, x1
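BBSNet drives this backbone stage by stage rather than through ResNet50.forward. A hedged sketch of the stage resolutions for a 256x256 input (sizes match the comments in BBSNet.forward):

    import torch
    from ResNet import ResNet50  # inside this repo the import is relative

    net = ResNet50("rgb")
    x = torch.randn(1, 3, 256, 256)
    x = net.maxpool(net.relu(net.bn1(net.conv1(x))))  # 64 x 64 x 64
    x1 = net.layer1(x)     # 256 x 64 x 64
    x2 = net.layer2(x1)    # 512 x 32 x 32
    x3 = net.layer3_1(x2)  # 1024 x 16 x 16
    x4 = net.layer4_1(x3)  # 2048 x 8 x 8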
modeling_bbsnet.py CHANGED
@@ -3,9 +3,8 @@ from typing import Dict, Optional
 from torch import Tensor, nn
 from transformers import PreTrainedModel
 
-from models.BBSNet_model import BBSNet
-
 from .configuration_bbsnet import BBSNetConfig
+from .BBSNet_model import BBSNet
 
 
 class BBSNetModel(PreTrainedModel):
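Switching from the absolute models.BBSNet_model import to the relative .BBSNet_model import (together with uploading BBSNet_model.py and ResNet.py above) makes the repository self-contained, which is what custom_code models on the Hub require. A hedged loading sketch; the repo id here is a guess based on the committer name and may differ:

    from transformers import AutoModel

    # hypothetical repo id; replace with the actual Hub path
    model = AutoModel.from_pretrained(
        "thinh-huynh-re/bbsnet", trust_remote_code=True
    )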