pythn committed on
Commit 4a1f918 (verified) · 1 parent: 036a92d

Upload with huggingface_hub

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +4 -0
  2. .gitignore +14 -0
  3. AllinonSAM/LICENSE +21 -0
  4. AllinonSAM/README.md +0 -0
  5. AllinonSAM/__pycache__/axialnet.cpython-38.pyc +0 -0
  6. AllinonSAM/__pycache__/baselines.cpython-38.pyc +0 -0
  7. AllinonSAM/__pycache__/combined_model.cpython-38.pyc +0 -0
  8. AllinonSAM/__pycache__/data_utils.cpython-312.pyc +0 -0
  9. AllinonSAM/__pycache__/data_utils.cpython-38.pyc +0 -0
  10. AllinonSAM/__pycache__/model.cpython-312.pyc +0 -0
  11. AllinonSAM/__pycache__/model.cpython-38.pyc +0 -0
  12. AllinonSAM/__pycache__/test.cpython-312.pyc +0 -0
  13. AllinonSAM/__pycache__/test.cpython-38.pyc +0 -0
  14. AllinonSAM/__pycache__/train.cpython-312.pyc +0 -0
  15. AllinonSAM/__pycache__/train.cpython-38.pyc +0 -0
  16. AllinonSAM/__pycache__/utils.cpython-312.pyc +0 -0
  17. AllinonSAM/__pycache__/utils.cpython-38.pyc +0 -0
  18. AllinonSAM/__pycache__/vit_seg_configs.cpython-38.pyc +0 -0
  19. AllinonSAM/__pycache__/vit_seg_modeling.cpython-38.pyc +0 -0
  20. AllinonSAM/__pycache__/vit_seg_modeling_resnet_skip.cpython-38.pyc +0 -0
  21. AllinonSAM/axialnet.py +730 -0
  22. AllinonSAM/baselines.py +630 -0
  23. AllinonSAM/biastuning/DIAS/labels/epoch_0_batch_0_img_0.png +0 -0
  24. AllinonSAM/biastuning/DIAS/labels/epoch_0_batch_0_img_1.png +0 -0
  25. AllinonSAM/biastuning/DIAS/labels/epoch_0_batch_1_img_0.png +0 -0
  26. AllinonSAM/biastuning/DIAS/labels/epoch_0_batch_1_img_1.png +0 -0
  27. AllinonSAM/biastuning/DIAS/labels/epoch_100_batch_0_img_0.png +0 -0
  28. AllinonSAM/biastuning/DIAS/labels/epoch_100_batch_0_img_1.png +0 -0
  29. AllinonSAM/biastuning/DIAS/labels/epoch_100_batch_1_img_0.png +0 -0
  30. AllinonSAM/biastuning/DIAS/labels/epoch_100_batch_1_img_1.png +0 -0
  31. AllinonSAM/biastuning/DIAS/labels/epoch_10_batch_0_img_0.png +0 -0
  32. AllinonSAM/biastuning/DIAS/labels/epoch_10_batch_0_img_1.png +0 -0
  33. AllinonSAM/biastuning/DIAS/labels/epoch_10_batch_1_img_0.png +0 -0
  34. AllinonSAM/biastuning/DIAS/labels/epoch_10_batch_1_img_1.png +0 -0
  35. AllinonSAM/biastuning/DIAS/labels/epoch_110_batch_0_img_0.png +0 -0
  36. AllinonSAM/biastuning/DIAS/labels/epoch_110_batch_0_img_1.png +0 -0
  37. AllinonSAM/biastuning/DIAS/labels/epoch_110_batch_1_img_0.png +0 -0
  38. AllinonSAM/biastuning/DIAS/labels/epoch_110_batch_1_img_1.png +0 -0
  39. AllinonSAM/biastuning/DIAS/labels/epoch_120_batch_0_img_0.png +0 -0
  40. AllinonSAM/biastuning/DIAS/labels/epoch_120_batch_0_img_1.png +0 -0
  41. AllinonSAM/biastuning/DIAS/labels/epoch_120_batch_1_img_0.png +0 -0
  42. AllinonSAM/biastuning/DIAS/labels/epoch_120_batch_1_img_1.png +0 -0
  43. AllinonSAM/biastuning/DIAS/labels/epoch_130_batch_0_img_0.png +0 -0
  44. AllinonSAM/biastuning/DIAS/labels/epoch_130_batch_0_img_1.png +0 -0
  45. AllinonSAM/biastuning/DIAS/labels/epoch_130_batch_1_img_0.png +0 -0
  46. AllinonSAM/biastuning/DIAS/labels/epoch_130_batch_1_img_1.png +0 -0
  47. AllinonSAM/biastuning/DIAS/labels/epoch_140_batch_0_img_0.png +0 -0
  48. AllinonSAM/biastuning/DIAS/labels/epoch_140_batch_0_img_1.png +0 -0
  49. AllinonSAM/biastuning/DIAS/labels/epoch_140_batch_1_img_0.png +0 -0
  50. AllinonSAM/biastuning/DIAS/labels/epoch_140_batch_1_img_1.png +0 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ AllinonSAM/eval/lits/output_demo.nii filter=lfs diff=lfs merge=lfs -text
+ AllinonSAM/wandb/run-20241018_210810-zrrx3qz9/run-zrrx3qz9.wandb filter=lfs diff=lfs merge=lfs -text
+ AllinonSAM/wandb/run-20241018_162125-i4stmvih/run-i4stmvih.wandb filter=lfs diff=lfs merge=lfs -text
+ AllinonSAM/wandb/run-20240915_215641-1usjns7w/run-1usjns7w.wandb filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,14 @@
+ *.pyc
+ *.cpython-38.pyc
+ *.pth
+ *.gz
+ *.zip
+ *.png
+ *.jpg
+ *.JPG
+ *.tif
+ *.bmp
+ *.out
+ *.txt
+ AllinonSAM/wandb/
+ __pycache__
AllinonSAM/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Ahmed Heakl
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
AllinonSAM/README.md ADDED
File without changes
AllinonSAM/__pycache__/axialnet.cpython-38.pyc ADDED
Binary file (17.3 kB).
AllinonSAM/__pycache__/baselines.cpython-38.pyc ADDED
Binary file (16.8 kB).
AllinonSAM/__pycache__/combined_model.cpython-38.pyc ADDED
Binary file (937 Bytes).
AllinonSAM/__pycache__/data_utils.cpython-312.pyc ADDED
Binary file (81.1 kB).
AllinonSAM/__pycache__/data_utils.cpython-38.pyc ADDED
Binary file (40.9 kB).
AllinonSAM/__pycache__/model.cpython-312.pyc ADDED
Binary file (15.6 kB).
AllinonSAM/__pycache__/model.cpython-38.pyc ADDED
Binary file (8.11 kB).
AllinonSAM/__pycache__/test.cpython-312.pyc ADDED
Binary file (3.98 kB).
AllinonSAM/__pycache__/test.cpython-38.pyc ADDED
Binary file (1.89 kB).
AllinonSAM/__pycache__/train.cpython-312.pyc ADDED
Binary file (13 kB).
AllinonSAM/__pycache__/train.cpython-38.pyc ADDED
Binary file (8.91 kB).
AllinonSAM/__pycache__/utils.cpython-312.pyc ADDED
Binary file (6.85 kB).
AllinonSAM/__pycache__/utils.cpython-38.pyc ADDED
Binary file (12.3 kB).
AllinonSAM/__pycache__/vit_seg_configs.cpython-38.pyc ADDED
Binary file (3.34 kB).
AllinonSAM/__pycache__/vit_seg_modeling.cpython-38.pyc ADDED
Binary file (14.5 kB).
AllinonSAM/__pycache__/vit_seg_modeling_resnet_skip.cpython-38.pyc ADDED
Binary file (5.88 kB).
AllinonSAM/axialnet.py ADDED
@@ -0,0 +1,730 @@
1
+ import pdb
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from utils import *
7
+ import pdb
8
+ import matplotlib.pyplot as plt
9
+
10
+ import random
11
+
12
+
13
+
14
+ def conv1x1(in_planes, out_planes, stride=1):
15
+ """1x1 convolution"""
16
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
17
+
18
+
19
+ class AxialAttention(nn.Module):
20
+ def __init__(self, in_planes, out_planes, groups=8, kernel_size=56,
21
+ stride=1, bias=False, width=False):
22
+ assert (in_planes % groups == 0) and (out_planes % groups == 0)
23
+ super(AxialAttention, self).__init__()
24
+ self.in_planes = in_planes
25
+ self.out_planes = out_planes
26
+ self.groups = groups
27
+ self.group_planes = out_planes // groups
28
+ self.kernel_size = kernel_size
29
+ self.stride = stride
30
+ self.bias = bias
31
+ self.width = width
32
+
33
+ # Multi-head self attention
34
+ self.qkv_transform = qkv_transform(in_planes, out_planes * 2, kernel_size=1, stride=1,
35
+ padding=0, bias=False)
36
+ self.bn_qkv = nn.BatchNorm1d(out_planes * 2)
37
+ self.bn_similarity = nn.BatchNorm2d(groups * 3)
38
+
39
+ self.bn_output = nn.BatchNorm1d(out_planes * 2)
40
+
41
+ # Position embedding
42
+ self.relative = nn.Parameter(torch.randn(self.group_planes * 2, kernel_size * 2 - 1), requires_grad=True)
43
+ query_index = torch.arange(kernel_size).unsqueeze(0)
44
+ key_index = torch.arange(kernel_size).unsqueeze(1)
45
+ relative_index = key_index - query_index + kernel_size - 1
46
+ self.register_buffer('flatten_index', relative_index.view(-1))
47
+ if stride > 1:
48
+ self.pooling = nn.AvgPool2d(stride, stride=stride)
49
+
50
+ self.reset_parameters()
51
+
52
+ def forward(self, x):
53
+ # pdb.set_trace()
54
+ if self.width:
55
+ x = x.permute(0, 2, 1, 3)
56
+ else:
57
+ x = x.permute(0, 3, 1, 2) # N, W, C, H
58
+ N, W, C, H = x.shape
59
+ x = x.contiguous().view(N * W, C, H)
60
+
61
+ # Transformations
62
+ qkv = self.bn_qkv(self.qkv_transform(x))
63
+ q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H), [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2)
64
+
65
+ # Calculate position embedding
66
+ all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2, self.kernel_size, self.kernel_size)
67
+ q_embedding, k_embedding, v_embedding = torch.split(all_embeddings, [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=0)
68
+
69
+ qr = torch.einsum('bgci,cij->bgij', q, q_embedding)
70
+ kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3)
71
+
72
+ qk = torch.einsum('bgci, bgcj->bgij', q, k)
73
+
74
+ stacked_similarity = torch.cat([qk, qr, kr], dim=1)
75
+ stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1)
76
+ #stacked_similarity = self.bn_qr(qr) + self.bn_kr(kr) + self.bn_qk(qk)
77
+ # (N, groups, H, H, W)
78
+ similarity = F.softmax(stacked_similarity, dim=3)
79
+ sv = torch.einsum('bgij,bgcj->bgci', similarity, v)
80
+ sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding)
81
+ stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H)
82
+ output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2)
83
+
84
+ if self.width:
85
+ output = output.permute(0, 2, 1, 3)
86
+ else:
87
+ output = output.permute(0, 2, 3, 1)
88
+
89
+ if self.stride > 1:
90
+ output = self.pooling(output)
91
+
92
+ return output
93
+
94
+ def reset_parameters(self):
95
+ self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_planes))
96
+ #nn.init.uniform_(self.relative, -0.1, 0.1)
97
+ nn.init.normal_(self.relative, 0., math.sqrt(1. / self.group_planes))
98
+
99
+ class AxialAttention_dynamic(nn.Module):
100
+ def __init__(self, in_planes, out_planes, groups=8, kernel_size=56,
101
+ stride=1, bias=False, width=False):
102
+ assert (in_planes % groups == 0) and (out_planes % groups == 0)
103
+ super(AxialAttention_dynamic, self).__init__()
104
+ self.in_planes = in_planes
105
+ self.out_planes = out_planes
106
+ self.groups = groups
107
+ self.group_planes = out_planes // groups
108
+ self.kernel_size = kernel_size
109
+ self.stride = stride
110
+ self.bias = bias
111
+ self.width = width
112
+
113
+ # Multi-head self attention
114
+ self.qkv_transform = qkv_transform(in_planes, out_planes * 2, kernel_size=1, stride=1,
115
+ padding=0, bias=False)
116
+ self.bn_qkv = nn.BatchNorm1d(out_planes * 2)
117
+ self.bn_similarity = nn.BatchNorm2d(groups * 3)
118
+ self.bn_output = nn.BatchNorm1d(out_planes * 2)
119
+
120
+ # Priority on encoding
121
+
122
+ ## Initial values
123
+
124
+ self.f_qr = nn.Parameter(torch.tensor(0.1), requires_grad=False)
125
+ self.f_kr = nn.Parameter(torch.tensor(0.1), requires_grad=False)
126
+ self.f_sve = nn.Parameter(torch.tensor(0.1), requires_grad=False)
127
+ self.f_sv = nn.Parameter(torch.tensor(1.0), requires_grad=False)
128
+
129
+
130
+ # Position embedding
131
+ self.relative = nn.Parameter(torch.randn(self.group_planes * 2, kernel_size * 2 - 1), requires_grad=True)
132
+ query_index = torch.arange(kernel_size).unsqueeze(0)
133
+ key_index = torch.arange(kernel_size).unsqueeze(1)
134
+ relative_index = key_index - query_index + kernel_size - 1
135
+ self.register_buffer('flatten_index', relative_index.view(-1))
136
+ if stride > 1:
137
+ self.pooling = nn.AvgPool2d(stride, stride=stride)
138
+
139
+ self.reset_parameters()
140
+ # self.print_para()
141
+
142
+ def forward(self, x):
143
+ if self.width:
144
+ x = x.permute(0, 2, 1, 3)
145
+ else:
146
+ x = x.permute(0, 3, 1, 2) # N, W, C, H
147
+ N, W, C, H = x.shape
148
+ x = x.contiguous().view(N * W, C, H)
149
+
150
+ # Transformations
151
+ qkv = self.bn_qkv(self.qkv_transform(x))
152
+ q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H), [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2)
153
+
154
+ # Calculate position embedding
155
+ all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2, self.kernel_size, self.kernel_size)
156
+ q_embedding, k_embedding, v_embedding = torch.split(all_embeddings, [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=0)
157
+ qr = torch.einsum('bgci,cij->bgij', q, q_embedding)
158
+ kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3)
159
+ qk = torch.einsum('bgci, bgcj->bgij', q, k)
160
+
161
+
162
+ # multiply by factors
163
+ qr = torch.mul(qr, self.f_qr)
164
+ kr = torch.mul(kr, self.f_kr)
165
+
166
+ stacked_similarity = torch.cat([qk, qr, kr], dim=1)
167
+ stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1)
168
+ #stacked_similarity = self.bn_qr(qr) + self.bn_kr(kr) + self.bn_qk(qk)
169
+ # (N, groups, H, H, W)
170
+ similarity = F.softmax(stacked_similarity, dim=3)
171
+ sv = torch.einsum('bgij,bgcj->bgci', similarity, v)
172
+ sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding)
173
+
174
+ # multiply by factors
175
+ sv = torch.mul(sv, self.f_sv)
176
+ sve = torch.mul(sve, self.f_sve)
177
+
178
+ stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H)
179
+ output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2)
180
+
181
+ if self.width:
182
+ output = output.permute(0, 2, 1, 3)
183
+ else:
184
+ output = output.permute(0, 2, 3, 1)
185
+
186
+ if self.stride > 1:
187
+ output = self.pooling(output)
188
+
189
+ return output
190
+ def reset_parameters(self):
191
+ self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_planes))
192
+ #nn.init.uniform_(self.relative, -0.1, 0.1)
193
+ nn.init.normal_(self.relative, 0., math.sqrt(1. / self.group_planes))
194
+
195
+ class AxialAttention_wopos(nn.Module):
196
+ def __init__(self, in_planes, out_planes, groups=8, kernel_size=56,
197
+ stride=1, bias=False, width=False):
198
+ assert (in_planes % groups == 0) and (out_planes % groups == 0)
199
+ super(AxialAttention_wopos, self).__init__()
200
+ self.in_planes = in_planes
201
+ self.out_planes = out_planes
202
+ self.groups = groups
203
+ self.group_planes = out_planes // groups
204
+ self.kernel_size = kernel_size
205
+ self.stride = stride
206
+ self.bias = bias
207
+ self.width = width
208
+
209
+ # Multi-head self attention
210
+ self.qkv_transform = qkv_transform(in_planes, out_planes * 2, kernel_size=1, stride=1,
211
+ padding=0, bias=False)
212
+ self.bn_qkv = nn.BatchNorm1d(out_planes * 2)
213
+ self.bn_similarity = nn.BatchNorm2d(groups )
214
+
215
+ self.bn_output = nn.BatchNorm1d(out_planes * 1)
216
+
217
+ if stride > 1:
218
+ self.pooling = nn.AvgPool2d(stride, stride=stride)
219
+
220
+ self.reset_parameters()
221
+
222
+ def forward(self, x):
223
+ if self.width:
224
+ x = x.permute(0, 2, 1, 3)
225
+ else:
226
+ x = x.permute(0, 3, 1, 2) # N, W, C, H
227
+ N, W, C, H = x.shape
228
+ x = x.contiguous().view(N * W, C, H)
229
+
230
+ # Transformations
231
+ qkv = self.bn_qkv(self.qkv_transform(x))
232
+ q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H), [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2)
233
+
234
+ qk = torch.einsum('bgci, bgcj->bgij', q, k)
235
+
236
+ stacked_similarity = self.bn_similarity(qk).reshape(N * W, 1, self.groups, H, H).sum(dim=1).contiguous()
237
+
238
+ similarity = F.softmax(stacked_similarity, dim=3)
239
+ sv = torch.einsum('bgij,bgcj->bgci', similarity, v)
240
+
241
+ sv = sv.reshape(N*W,self.out_planes * 1, H).contiguous()
242
+ output = self.bn_output(sv).reshape(N, W, self.out_planes, 1, H).sum(dim=-2).contiguous()
243
+
244
+
245
+ if self.width:
246
+ output = output.permute(0, 2, 1, 3)
247
+ else:
248
+ output = output.permute(0, 2, 3, 1)
249
+
250
+ if self.stride > 1:
251
+ output = self.pooling(output)
252
+
253
+ return output
254
+
255
+ def reset_parameters(self):
256
+ self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_planes))
257
+ #nn.init.uniform_(self.relative, -0.1, 0.1)
258
+ # nn.init.normal_(self.relative, 0., math.sqrt(1. / self.group_planes))
259
+
260
+ #end of attn definition
261
+
262
+ class AxialBlock(nn.Module):
263
+ expansion = 2
264
+
265
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
266
+ base_width=64, dilation=1, norm_layer=None, kernel_size=56):
267
+ super(AxialBlock, self).__init__()
268
+ if norm_layer is None:
269
+ norm_layer = nn.BatchNorm2d
270
+ width = int(planes * (base_width / 64.))
271
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
272
+ self.conv_down = conv1x1(inplanes, width)
273
+ self.bn1 = norm_layer(width)
274
+ self.hight_block = AxialAttention(width, width, groups=groups, kernel_size=kernel_size)
275
+ self.width_block = AxialAttention(width, width, groups=groups, kernel_size=kernel_size, stride=stride, width=True)
276
+ self.conv_up = conv1x1(width, planes * self.expansion)
277
+ self.bn2 = norm_layer(planes * self.expansion)
278
+ self.relu = nn.ReLU(inplace=True)
279
+ self.downsample = downsample
280
+ self.stride = stride
281
+
282
+ def forward(self, x):
283
+ identity = x
284
+
285
+ out = self.conv_down(x)
286
+ out = self.bn1(out)
287
+ out = self.relu(out)
288
+ # print(out.shape)
289
+ out = self.hight_block(out)
290
+ out = self.width_block(out)
291
+ out = self.relu(out)
292
+
293
+ out = self.conv_up(out)
294
+ out = self.bn2(out)
295
+
296
+ if self.downsample is not None:
297
+ identity = self.downsample(x)
298
+
299
+ out += identity
300
+ out = self.relu(out)
301
+
302
+ return out
303
+
304
+ class AxialBlock_dynamic(nn.Module):
305
+ expansion = 2
306
+
307
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
308
+ base_width=64, dilation=1, norm_layer=None, kernel_size=56):
309
+ super(AxialBlock_dynamic, self).__init__()
310
+ if norm_layer is None:
311
+ norm_layer = nn.BatchNorm2d
312
+ width = int(planes * (base_width / 64.))
313
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
314
+ self.conv_down = conv1x1(inplanes, width)
315
+ self.bn1 = norm_layer(width)
316
+ self.hight_block = AxialAttention_dynamic(width, width, groups=groups, kernel_size=kernel_size)
317
+ self.width_block = AxialAttention_dynamic(width, width, groups=groups, kernel_size=kernel_size, stride=stride, width=True)
318
+ self.conv_up = conv1x1(width, planes * self.expansion)
319
+ self.bn2 = norm_layer(planes * self.expansion)
320
+ self.relu = nn.ReLU(inplace=True)
321
+ self.downsample = downsample
322
+ self.stride = stride
323
+
324
+ def forward(self, x):
325
+ identity = x
326
+
327
+ out = self.conv_down(x)
328
+ out = self.bn1(out)
329
+ out = self.relu(out)
330
+
331
+ out = self.hight_block(out)
332
+ out = self.width_block(out)
333
+ out = self.relu(out)
334
+
335
+ out = self.conv_up(out)
336
+ out = self.bn2(out)
337
+
338
+ if self.downsample is not None:
339
+ identity = self.downsample(x)
340
+
341
+ out += identity
342
+ out = self.relu(out)
343
+
344
+ return out
345
+
346
+ class AxialBlock_wopos(nn.Module):
347
+ expansion = 2
348
+
349
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
350
+ base_width=64, dilation=1, norm_layer=None, kernel_size=56):
351
+ super(AxialBlock_wopos, self).__init__()
352
+ if norm_layer is None:
353
+ norm_layer = nn.BatchNorm2d
354
+ # print(kernel_size)
355
+ width = int(planes * (base_width / 64.))
356
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
357
+ self.conv_down = conv1x1(inplanes, width)
358
+ self.conv1 = nn.Conv2d(width, width, kernel_size = 1)
359
+ self.bn1 = norm_layer(width)
360
+ self.hight_block = AxialAttention_wopos(width, width, groups=groups, kernel_size=kernel_size)
361
+ self.width_block = AxialAttention_wopos(width, width, groups=groups, kernel_size=kernel_size, stride=stride, width=True)
362
+ self.conv_up = conv1x1(width, planes * self.expansion)
363
+ self.bn2 = norm_layer(planes * self.expansion)
364
+ self.relu = nn.ReLU(inplace=True)
365
+ self.downsample = downsample
366
+ self.stride = stride
367
+
368
+ def forward(self, x):
369
+ identity = x
370
+
371
+ # pdb.set_trace()
372
+
373
+ out = self.conv_down(x)
374
+ out = self.bn1(out)
375
+ out = self.relu(out)
376
+ # print(out.shape)
377
+ out = self.hight_block(out)
378
+ out = self.width_block(out)
379
+
380
+ out = self.relu(out)
381
+
382
+ out = self.conv_up(out)
383
+ out = self.bn2(out)
384
+
385
+ if self.downsample is not None:
386
+ identity = self.downsample(x)
387
+
388
+ out += identity
389
+ out = self.relu(out)
390
+
391
+ return out
392
+
393
+
394
+ #end of block definition
395
+
396
+
397
+ class ResAxialAttentionUNet(nn.Module):
398
+
399
+ def __init__(self, block, layers, num_classes=2, zero_init_residual=True,
400
+ groups=8, width_per_group=64, replace_stride_with_dilation=None,
401
+ norm_layer=None, s=0.125, img_size = 128,imgchan = 3):
402
+ super(ResAxialAttentionUNet, self).__init__()
403
+ if norm_layer is None:
404
+ norm_layer = nn.BatchNorm2d
405
+ self._norm_layer = norm_layer
406
+
407
+ self.inplanes = int(64 * s)
408
+ self.dilation = 1
409
+ if replace_stride_with_dilation is None:
410
+ replace_stride_with_dilation = [False, False, False]
411
+ if len(replace_stride_with_dilation) != 3:
412
+ raise ValueError("replace_stride_with_dilation should be None "
413
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
414
+ self.groups = groups
415
+ self.base_width = width_per_group
416
+ self.conv1 = nn.Conv2d(imgchan, self.inplanes, kernel_size=7, stride=2, padding=3,
417
+ bias=False)
418
+ self.conv2 = nn.Conv2d(self.inplanes, 128, kernel_size=3, stride=1, padding=1, bias=False)
419
+ self.conv3 = nn.Conv2d(128, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
420
+ self.bn1 = norm_layer(self.inplanes)
421
+ self.bn2 = norm_layer(128)
422
+ self.bn3 = norm_layer(self.inplanes)
423
+ self.relu = nn.ReLU(inplace=True)
424
+ # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
425
+ self.layer1 = self._make_layer(block, int(128 * s), layers[0], kernel_size= (img_size//2))
426
+ self.layer2 = self._make_layer(block, int(256 * s), layers[1], stride=2, kernel_size=(img_size//2),
427
+ dilate=replace_stride_with_dilation[0])
428
+ self.layer3 = self._make_layer(block, int(512 * s), layers[2], stride=2, kernel_size=(img_size//4),
429
+ dilate=replace_stride_with_dilation[1])
430
+ self.layer4 = self._make_layer(block, int(1024 * s), layers[3], stride=2, kernel_size=(img_size//8),
431
+ dilate=replace_stride_with_dilation[2])
432
+
433
+ # Decoder
434
+ self.decoder1 = nn.Conv2d(int(1024 *2*s) , int(1024*2*s), kernel_size=3, stride=2, padding=1)
435
+ self.decoder2 = nn.Conv2d(int(1024 *2*s) , int(1024*s), kernel_size=3, stride=1, padding=1)
436
+ self.decoder3 = nn.Conv2d(int(1024*s), int(512*s), kernel_size=3, stride=1, padding=1)
437
+ self.decoder4 = nn.Conv2d(int(512*s) , int(256*s), kernel_size=3, stride=1, padding=1)
438
+ self.decoder5 = nn.Conv2d(int(256*s) , int(128*s) , kernel_size=3, stride=1, padding=1)
439
+ self.adjust = nn.Conv2d(int(128*s) , num_classes, kernel_size=1, stride=1, padding=0)
440
+ self.soft = nn.Softmax(dim=1)
441
+
442
+
443
+ def _make_layer(self, block, planes, blocks, kernel_size=56, stride=1, dilate=False):
444
+ norm_layer = self._norm_layer
445
+ downsample = None
446
+ previous_dilation = self.dilation
447
+ if dilate:
448
+ self.dilation *= stride
449
+ stride = 1
450
+ if stride != 1 or self.inplanes != planes * block.expansion:
451
+ downsample = nn.Sequential(
452
+ conv1x1(self.inplanes, planes * block.expansion, stride),
453
+ norm_layer(planes * block.expansion),
454
+ )
455
+
456
+ layers = []
457
+ layers.append(block(self.inplanes, planes, stride, downsample, groups=self.groups,
458
+ base_width=self.base_width, dilation=previous_dilation,
459
+ norm_layer=norm_layer, kernel_size=kernel_size))
460
+ self.inplanes = planes * block.expansion
461
+ if stride != 1:
462
+ kernel_size = kernel_size // 2
463
+
464
+ for _ in range(1, blocks):
465
+ layers.append(block(self.inplanes, planes, groups=self.groups,
466
+ base_width=self.base_width, dilation=self.dilation,
467
+ norm_layer=norm_layer, kernel_size=kernel_size))
468
+
469
+ return nn.Sequential(*layers)
470
+
471
+ def _forward_impl(self, x):
472
+
473
+ # AxialAttention Encoder
474
+ # pdb.set_trace()
475
+ x = self.conv1(x)
476
+ x = self.bn1(x)
477
+ x = self.relu(x)
478
+ x = self.conv2(x)
479
+ x = self.bn2(x)
480
+ x = self.relu(x)
481
+ x = self.conv3(x)
482
+ x = self.bn3(x)
483
+ x = self.relu(x)
484
+
485
+ x1 = self.layer1(x)
486
+
487
+ x2 = self.layer2(x1)
488
+ # print(x2.shape)
489
+ x3 = self.layer3(x2)
490
+ # print(x3.shape)
491
+ x4 = self.layer4(x3)
492
+
493
+ x = F.relu(F.interpolate(self.decoder1(x4), scale_factor=(2,2), mode ='bilinear'))
494
+ x = torch.add(x, x4)
495
+ x = F.relu(F.interpolate(self.decoder2(x) , scale_factor=(2,2), mode ='bilinear'))
496
+ x = torch.add(x, x3)
497
+ x = F.relu(F.interpolate(self.decoder3(x) , scale_factor=(2,2), mode ='bilinear'))
498
+ x = torch.add(x, x2)
499
+ x = F.relu(F.interpolate(self.decoder4(x) , scale_factor=(2,2), mode ='bilinear'))
500
+ x = torch.add(x, x1)
501
+ x = F.relu(F.interpolate(self.decoder5(x) , scale_factor=(2,2), mode ='bilinear'))
502
+ x = self.adjust(F.relu(x))
503
+ # pdb.set_trace()
504
+ return x
505
+
506
+ def forward(self, x):
507
+ return self._forward_impl(x)
508
+
509
+ class medt_net(nn.Module):
510
+
511
+ def __init__(self, block, block_2, layers, num_classes=2, zero_init_residual=True,
512
+ groups=8, width_per_group=64, replace_stride_with_dilation=None,
513
+ norm_layer=None, s=0.125, img_size = 128,imgchan = 3):
514
+ super(medt_net, self).__init__()
515
+ if norm_layer is None:
516
+ norm_layer = nn.BatchNorm2d
517
+ self._norm_layer = norm_layer
518
+
519
+ self.inplanes = int(64 * s)
520
+ self.dilation = 1
521
+ if replace_stride_with_dilation is None:
522
+ replace_stride_with_dilation = [False, False, False]
523
+ if len(replace_stride_with_dilation) != 3:
524
+ raise ValueError("replace_stride_with_dilation should be None "
525
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
526
+ self.groups = groups
527
+ self.base_width = width_per_group
528
+ self.conv1 = nn.Conv2d(imgchan, self.inplanes, kernel_size=7, stride=2, padding=3,
529
+ bias=False)
530
+ self.conv2 = nn.Conv2d(self.inplanes, 128, kernel_size=3, stride=1, padding=1, bias=False)
531
+ self.conv3 = nn.Conv2d(128, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
532
+ self.bn1 = norm_layer(self.inplanes)
533
+ self.bn2 = norm_layer(128)
534
+ self.bn3 = norm_layer(self.inplanes)
535
+ # self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
536
+ self.bn1 = norm_layer(self.inplanes)
537
+ self.relu = nn.ReLU(inplace=True)
538
+ # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
539
+ self.layer1 = self._make_layer(block, int(128 * s), layers[0], kernel_size= (img_size//2))
540
+ self.layer2 = self._make_layer(block, int(256 * s), layers[1], stride=2, kernel_size=(img_size//2),
541
+ dilate=replace_stride_with_dilation[0])
542
+ # self.layer3 = self._make_layer(block, int(512 * s), layers[2], stride=2, kernel_size=(img_size//4),
543
+ # dilate=replace_stride_with_dilation[1])
544
+ # self.layer4 = self._make_layer(block, int(1024 * s), layers[3], stride=2, kernel_size=(img_size//8),
545
+ # dilate=replace_stride_with_dilation[2])
546
+
547
+ # Decoder
548
+ # self.decoder1 = nn.Conv2d(int(1024 *2*s) , int(1024*2*s), kernel_size=3, stride=2, padding=1)
549
+ # self.decoder2 = nn.Conv2d(int(1024 *2*s) , int(1024*s), kernel_size=3, stride=1, padding=1)
550
+ # self.decoder3 = nn.Conv2d(int(1024*s), int(512*s), kernel_size=3, stride=1, padding=1)
551
+ self.decoder4 = nn.Conv2d(int(512*s) , int(256*s), kernel_size=3, stride=1, padding=1)
552
+ self.decoder5 = nn.Conv2d(int(256*s) , int(128*s) , kernel_size=3, stride=1, padding=1)
553
+ self.adjust = nn.Conv2d(int(128*s) , num_classes, kernel_size=1, stride=1, padding=0)
554
+ self.soft = nn.Softmax(dim=1)
555
+
556
+
557
+ self.conv1_p = nn.Conv2d(imgchan, self.inplanes, kernel_size=7, stride=2, padding=3,
558
+ bias=False)
559
+ self.conv2_p = nn.Conv2d(self.inplanes,128, kernel_size=3, stride=1, padding=1,
560
+ bias=False)
561
+ self.conv3_p = nn.Conv2d(128, self.inplanes, kernel_size=3, stride=1, padding=1,
562
+ bias=False)
563
+ # self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
564
+ self.bn1_p = norm_layer(self.inplanes)
565
+ self.bn2_p = norm_layer(128)
566
+ self.bn3_p = norm_layer(self.inplanes)
567
+
568
+ self.relu_p = nn.ReLU(inplace=True)
569
+
570
+ img_size_p = img_size // 4
571
+
572
+ self.layer1_p = self._make_layer(block_2, int(128 * s), layers[0], kernel_size= (img_size_p//2))
573
+ self.layer2_p = self._make_layer(block_2, int(256 * s), layers[1], stride=2, kernel_size=(img_size_p//2),
574
+ dilate=replace_stride_with_dilation[0])
575
+ self.layer3_p = self._make_layer(block_2, int(512 * s), layers[2], stride=2, kernel_size=(img_size_p//4),
576
+ dilate=replace_stride_with_dilation[1])
577
+ self.layer4_p = self._make_layer(block_2, int(1024 * s), layers[3], stride=2, kernel_size=(img_size_p//8),
578
+ dilate=replace_stride_with_dilation[2])
579
+
580
+ # Decoder
581
+ self.decoder1_p = nn.Conv2d(int(1024 *2*s) , int(1024*2*s), kernel_size=3, stride=2, padding=1)
582
+ self.decoder2_p = nn.Conv2d(int(1024 *2*s) , int(1024*s), kernel_size=3, stride=1, padding=1)
583
+ self.decoder3_p = nn.Conv2d(int(1024*s), int(512*s), kernel_size=3, stride=1, padding=1)
584
+ self.decoder4_p = nn.Conv2d(int(512*s) , int(256*s), kernel_size=3, stride=1, padding=1)
585
+ self.decoder5_p = nn.Conv2d(int(256*s) , int(128*s) , kernel_size=3, stride=1, padding=1)
586
+
587
+ self.decoderf = nn.Conv2d(int(128*s) , int(128*s) , kernel_size=3, stride=1, padding=1)
588
+ self.adjust_p = nn.Conv2d(int(128*s) , num_classes, kernel_size=1, stride=1, padding=0)
589
+ self.soft_p = nn.Softmax(dim=1)
590
+
591
+
592
+ def _make_layer(self, block, planes, blocks, kernel_size=56, stride=1, dilate=False):
593
+ norm_layer = self._norm_layer
594
+ downsample = None
595
+ previous_dilation = self.dilation
596
+ if dilate:
597
+ self.dilation *= stride
598
+ stride = 1
599
+ if stride != 1 or self.inplanes != planes * block.expansion:
600
+ downsample = nn.Sequential(
601
+ conv1x1(self.inplanes, planes * block.expansion, stride),
602
+ norm_layer(planes * block.expansion),
603
+ )
604
+
605
+ layers = []
606
+ layers.append(block(self.inplanes, planes, stride, downsample, groups=self.groups,
607
+ base_width=self.base_width, dilation=previous_dilation,
608
+ norm_layer=norm_layer, kernel_size=kernel_size))
609
+ self.inplanes = planes * block.expansion
610
+ if stride != 1:
611
+ kernel_size = kernel_size // 2
612
+
613
+ for _ in range(1, blocks):
614
+ layers.append(block(self.inplanes, planes, groups=self.groups,
615
+ base_width=self.base_width, dilation=self.dilation,
616
+ norm_layer=norm_layer, kernel_size=kernel_size))
617
+
618
+ return nn.Sequential(*layers)
619
+
620
+ def _forward_impl(self, x):
621
+
622
+ xin = x.clone()
623
+ x = self.conv1(x)
624
+ x = self.bn1(x)
625
+ x = self.relu(x)
626
+ x = self.conv2(x)
627
+ x = self.bn2(x)
628
+ x = self.relu(x)
629
+ x = self.conv3(x)
630
+ x = self.bn3(x)
631
+ # x = F.max_pool2d(x,2,2)
632
+ x = self.relu(x)
633
+
634
+ # x = self.maxpool(x)
635
+ # pdb.set_trace()
636
+ x1 = self.layer1(x)
637
+ # print(x1.shape)
638
+ x2 = self.layer2(x1)
639
+ # print(x2.shape)
640
+ # x3 = self.layer3(x2)
641
+ # # print(x3.shape)
642
+ # x4 = self.layer4(x3)
643
+ # # print(x4.shape)
644
+ # x = F.relu(F.interpolate(self.decoder1(x4), scale_factor=(2,2), mode ='bilinear'))
645
+ # x = torch.add(x, x4)
646
+ # x = F.relu(F.interpolate(self.decoder2(x4) , scale_factor=(2,2), mode ='bilinear'))
647
+ # x = torch.add(x, x3)
648
+ # x = F.relu(F.interpolate(self.decoder3(x3) , scale_factor=(2,2), mode ='bilinear'))
649
+ # x = torch.add(x, x2)
650
+ x = F.relu(F.interpolate(self.decoder4(x2) , scale_factor=(2,2), mode ='bilinear'))
651
+ x = torch.add(x, x1)
652
+ x = F.relu(F.interpolate(self.decoder5(x) , scale_factor=(2,2), mode ='bilinear'))
653
+ # print(x.shape)
654
+
655
+ # end of full image training
656
+
657
+ # y_out = torch.ones((1,2,128,128))
658
+ x_loc = x.clone()
659
+ # x = F.relu(F.interpolate(self.decoder5(x) , scale_factor=(2,2), mode ='bilinear'))
660
+ #start
661
+ for i in range(0,4):
662
+ for j in range(0,4):
663
+
664
+ x_p = xin[:,:,32*i:32*(i+1),32*j:32*(j+1)]
665
+ # begin patch wise
666
+ x_p = self.conv1_p(x_p)
667
+ x_p = self.bn1_p(x_p)
668
+ # x = F.max_pool2d(x,2,2)
669
+ x_p = self.relu(x_p)
670
+
671
+ x_p = self.conv2_p(x_p)
672
+ x_p = self.bn2_p(x_p)
673
+ # x = F.max_pool2d(x,2,2)
674
+ x_p = self.relu(x_p)
675
+ x_p = self.conv3_p(x_p)
676
+ x_p = self.bn3_p(x_p)
677
+ # x = F.max_pool2d(x,2,2)
678
+ x_p = self.relu(x_p)
679
+
680
+ # x = self.maxpool(x)
681
+ # pdb.set_trace()
682
+ x1_p = self.layer1_p(x_p)
683
+ # print(x1.shape)
684
+ x2_p = self.layer2_p(x1_p)
685
+ # print(x2.shape)
686
+ x3_p = self.layer3_p(x2_p)
687
+ # # print(x3.shape)
688
+ x4_p = self.layer4_p(x3_p)
689
+
690
+ x_p = F.relu(F.interpolate(self.decoder1_p(x4_p), scale_factor=(2,2), mode ='bilinear'))
691
+ x_p = torch.add(x_p, x4_p)
692
+ x_p = F.relu(F.interpolate(self.decoder2_p(x_p) , scale_factor=(2,2), mode ='bilinear'))
693
+ x_p = torch.add(x_p, x3_p)
694
+ x_p = F.relu(F.interpolate(self.decoder3_p(x_p) , scale_factor=(2,2), mode ='bilinear'))
695
+ x_p = torch.add(x_p, x2_p)
696
+ x_p = F.relu(F.interpolate(self.decoder4_p(x_p) , scale_factor=(2,2), mode ='bilinear'))
697
+ x_p = torch.add(x_p, x1_p)
698
+ x_p = F.relu(F.interpolate(self.decoder5_p(x_p) , scale_factor=(2,2), mode ='bilinear'))
699
+
700
+ x_loc[:,:,32*i:32*(i+1),32*j:32*(j+1)] = x_p
701
+
702
+ x = torch.add(x,x_loc)
703
+ x = F.relu(self.decoderf(x))
704
+
705
+ x = self.adjust(F.relu(x))
706
+
707
+ # pdb.set_trace()
708
+ return x
709
+
710
+ def forward(self, x, text_dummy):
711
+ return self.soft(self._forward_impl(x)),0
712
+
713
+
714
+ def axialunet(pretrained=False, **kwargs):
715
+ model = ResAxialAttentionUNet(AxialBlock, [1, 2, 4, 1], s= 0.125, **kwargs)
716
+ return model
717
+
718
+ def gated(pretrained=False, **kwargs):
719
+ model = ResAxialAttentionUNet(AxialBlock_dynamic, [1, 2, 4, 1], s= 0.125, **kwargs)
720
+ return model
721
+
722
+ def MedT(pretrained=False, **kwargs):
723
+ model = medt_net(AxialBlock_dynamic,AxialBlock_wopos, [1, 2, 4, 1], s= 0.125, **kwargs)
724
+ return model
725
+
726
+ def logo(pretrained=False, **kwargs):
727
+ model = medt_net(AxialBlock,AxialBlock, [1, 2, 4, 1], s= 0.125, **kwargs)
728
+ return model
729
+
730
+ # EOF
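
For orientation, the snippet below is a minimal usage sketch (not part of the commit) of the factory functions that close out axialnet.py (axialunet, gated, MedT, logo). It assumes the script is run from the AllinonSAM/ directory so that axialnet's `from utils import *` (which supplies qkv_transform) resolves, and it uses the default img_size of 128, consistent with medt_net's hard-coded 4x4 grid of 32x32 patches.

# Illustrative sketch only; assumes AllinonSAM/ is the working directory
# so that axialnet's `from utils import *` (providing qkv_transform) imports cleanly.
import torch
from axialnet import MedT, gated

# medt_net tiles the input into a 4x4 grid of 32x32 crops, so a 128x128 input is assumed.
model = MedT(img_size=128, imgchan=3, num_classes=2)
x = torch.randn(1, 3, 128, 128)
probs, _ = model(x, None)        # forward(x, text_dummy) returns (softmaxed map, 0)
print(probs.shape)               # torch.Size([1, 2, 128, 128])

# The gated variant (ResAxialAttentionUNet with dynamic axial attention) returns raw logits.
logits = gated(img_size=128, imgchan=3, num_classes=2)(x)
print(logits.shape)              # torch.Size([1, 2, 128, 128])
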
AllinonSAM/baselines.py ADDED
@@ -0,0 +1,630 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from backbones_unet.model.unet import Unet
4
+ import torch.nn.functional as F
5
+ from utils import *
6
+ __all__ = ['UNext']
7
+
8
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
9
+ import math
10
+
11
+ class UNet(nn.Module):
12
+ def __init__(self, in_channels = 3, out_channels = 1, init_features = 32, pretrained=True , back_bone=None):
13
+ super().__init__()
14
+ if back_bone is None:
15
+ self.model = torch.hub.load(
16
+ 'mateuszbuda/brain-segmentation-pytorch', 'unet', in_channels=in_channels, out_channels=out_channels,
17
+ init_features=init_features, pretrained=pretrained
18
+ )
19
+ else:
20
+ self.model = Unet(  # the backbone path should use the imported backbones_unet Unet; calling UNet here would recurse
21
+ in_channels= in_channels,
22
+ out_channels= out_channels,
23
+ backbone=back_bone
24
+ )
25
+
26
+ self.soft = nn.Softmax(dim =1)
27
+ def forward(self, x, text_dummy):
28
+ return self.soft(self.model(x)),0
29
+
30
+
31
+ def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
32
+ """1x1 convolution"""
33
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)  # honor the stride argument, as in axialnet.py
34
+
35
+ class shiftmlp(nn.Module):
36
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., shift_size=5):
37
+ super().__init__()
38
+ out_features = out_features or in_features
39
+ hidden_features = hidden_features or in_features
40
+ self.dim = in_features
41
+ self.fc1 = nn.Linear(in_features, hidden_features)
42
+ self.dwconv = DWConv(hidden_features)
43
+ self.act = act_layer()
44
+ self.fc2 = nn.Linear(hidden_features, out_features)
45
+ self.drop = nn.Dropout(drop)
46
+
47
+ self.shift_size = shift_size
48
+ self.pad = shift_size // 2
49
+
50
+
51
+ self.apply(self._init_weights)
52
+
53
+ def _init_weights(self, m):
54
+ if isinstance(m, nn.Linear):
55
+ trunc_normal_(m.weight, std=.02)
56
+ if isinstance(m, nn.Linear) and m.bias is not None:
57
+ nn.init.constant_(m.bias, 0)
58
+ elif isinstance(m, nn.LayerNorm):
59
+ nn.init.constant_(m.bias, 0)
60
+ nn.init.constant_(m.weight, 1.0)
61
+ elif isinstance(m, nn.Conv2d):
62
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
63
+ fan_out //= m.groups
64
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
65
+ if m.bias is not None:
66
+ m.bias.data.zero_()
67
+
68
+
69
+ def forward(self, x, H, W):
70
+ # pdb.set_trace()
71
+ B, N, C = x.shape
72
+
73
+ xn = x.transpose(1, 2).view(B, C, H, W).contiguous()
74
+ xn = F.pad(xn, (self.pad, self.pad, self.pad, self.pad) , "constant", 0)
75
+ xs = torch.chunk(xn, self.shift_size, 1)
76
+ x_shift = [torch.roll(x_c, shift, 2) for x_c, shift in zip(xs, range(-self.pad, self.pad+1))]
77
+ x_cat = torch.cat(x_shift, 1)
78
+ x_cat = torch.narrow(x_cat, 2, self.pad, H)
79
+ x_s = torch.narrow(x_cat, 3, self.pad, W)
80
+
81
+
82
+ x_s = x_s.reshape(B,C,H*W).contiguous()
83
+ x_shift_r = x_s.transpose(1,2)
84
+
85
+
86
+ x = self.fc1(x_shift_r)
87
+
88
+ x = self.dwconv(x, H, W)
89
+ x = self.act(x)
90
+ x = self.drop(x)
91
+
92
+ xn = x.transpose(1, 2).view(B, C, H, W).contiguous()
93
+ xn = F.pad(xn, (self.pad, self.pad, self.pad, self.pad) , "constant", 0)
94
+ xs = torch.chunk(xn, self.shift_size, 1)
95
+ x_shift = [torch.roll(x_c, shift, 3) for x_c, shift in zip(xs, range(-self.pad, self.pad+1))]
96
+ x_cat = torch.cat(x_shift, 1)
97
+ x_cat = torch.narrow(x_cat, 2, self.pad, H)
98
+ x_s = torch.narrow(x_cat, 3, self.pad, W)
99
+ x_s = x_s.reshape(B,C,H*W).contiguous()
100
+ x_shift_c = x_s.transpose(1,2)
101
+
102
+ x = self.fc2(x_shift_c)
103
+ x = self.drop(x)
104
+ return x
105
+
106
+
107
+
108
+ class shiftedBlock(nn.Module):
109
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
110
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
111
+ super().__init__()
112
+
113
+
114
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
115
+ self.norm2 = norm_layer(dim)
116
+ mlp_hidden_dim = int(dim * mlp_ratio)
117
+ self.mlp = shiftmlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
118
+ self.apply(self._init_weights)
119
+
120
+ def _init_weights(self, m):
121
+ if isinstance(m, nn.Linear):
122
+ trunc_normal_(m.weight, std=.02)
123
+ if isinstance(m, nn.Linear) and m.bias is not None:
124
+ nn.init.constant_(m.bias, 0)
125
+ elif isinstance(m, nn.LayerNorm):
126
+ nn.init.constant_(m.bias, 0)
127
+ nn.init.constant_(m.weight, 1.0)
128
+ elif isinstance(m, nn.Conv2d):
129
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
130
+ fan_out //= m.groups
131
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
132
+ if m.bias is not None:
133
+ m.bias.data.zero_()
134
+
135
+ def forward(self, x, H, W):
136
+
137
+ x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
138
+ return x
139
+
140
+
141
+ class DWConv(nn.Module):
142
+ def __init__(self, dim=768):
143
+ super(DWConv, self).__init__()
144
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
145
+
146
+ def forward(self, x, H, W):
147
+ B, N, C = x.shape
148
+ x = x.transpose(1, 2).view(B, C, H, W)
149
+ x = self.dwconv(x)
150
+ x = x.flatten(2).transpose(1, 2)
151
+
152
+ return x
153
+
154
+ class OverlapPatchEmbed(nn.Module):
155
+ """ Image to Patch Embedding
156
+ """
157
+
158
+ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
159
+ super().__init__()
160
+ img_size = to_2tuple(img_size)
161
+ patch_size = to_2tuple(patch_size)
162
+
163
+ self.img_size = img_size
164
+ self.patch_size = patch_size
165
+ self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
166
+ self.num_patches = self.H * self.W
167
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
168
+ padding=(patch_size[0] // 2, patch_size[1] // 2))
169
+ self.norm = nn.LayerNorm(embed_dim)
170
+
171
+ self.apply(self._init_weights)
172
+
173
+ def _init_weights(self, m):
174
+ if isinstance(m, nn.Linear):
175
+ trunc_normal_(m.weight, std=.02)
176
+ if isinstance(m, nn.Linear) and m.bias is not None:
177
+ nn.init.constant_(m.bias, 0)
178
+ elif isinstance(m, nn.LayerNorm):
179
+ nn.init.constant_(m.bias, 0)
180
+ nn.init.constant_(m.weight, 1.0)
181
+ elif isinstance(m, nn.Conv2d):
182
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
183
+ fan_out //= m.groups
184
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
185
+ if m.bias is not None:
186
+ m.bias.data.zero_()
187
+
188
+ def forward(self, x):
189
+ x = self.proj(x)
190
+ _, _, H, W = x.shape
191
+ x = x.flatten(2).transpose(1, 2)
192
+ x = self.norm(x)
193
+
194
+ return x, H, W
195
+
196
+
197
+ class UNext(nn.Module):
198
+
199
+ ## Conv 3 + MLP 2 + shifted MLP
200
+
201
+ def __init__(self, num_classes, input_channels=3, deep_supervision=False,img_size=256, patch_size=16, in_chans=3, embed_dims=[ 128, 160, 256],
202
+ num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
203
+ attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
204
+ depths=[1, 1, 1], sr_ratios=[8, 4, 2, 1], **kwargs):
205
+ super().__init__()
206
+
207
+ self.encoder1 = nn.Conv2d(3, 16, 3, stride=1, padding=1)
208
+ self.encoder2 = nn.Conv2d(16, 32, 3, stride=1, padding=1)
209
+ self.encoder3 = nn.Conv2d(32, 128, 3, stride=1, padding=1)
210
+
211
+ self.ebn1 = nn.BatchNorm2d(16)
212
+ self.ebn2 = nn.BatchNorm2d(32)
213
+ self.ebn3 = nn.BatchNorm2d(128)
214
+
215
+ self.norm3 = norm_layer(embed_dims[1])
216
+ self.norm4 = norm_layer(embed_dims[2])
217
+
218
+ self.dnorm3 = norm_layer(160)
219
+ self.dnorm4 = norm_layer(128)
220
+
221
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
222
+
223
+ self.block1 = nn.ModuleList([shiftedBlock(
224
+ dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
225
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
226
+ sr_ratio=sr_ratios[0])])
227
+
228
+ self.block2 = nn.ModuleList([shiftedBlock(
229
+ dim=embed_dims[2], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
230
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
231
+ sr_ratio=sr_ratios[0])])
232
+
233
+ self.dblock1 = nn.ModuleList([shiftedBlock(
234
+ dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
235
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
236
+ sr_ratio=sr_ratios[0])])
237
+
238
+ self.dblock2 = nn.ModuleList([shiftedBlock(
239
+ dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
240
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
241
+ sr_ratio=sr_ratios[0])])
242
+
243
+ self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
244
+ embed_dim=embed_dims[1])
245
+ self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
246
+ embed_dim=embed_dims[2])
247
+
248
+ self.decoder1 = nn.Conv2d(256, 160, 3, stride=1,padding=1)
249
+ self.decoder2 = nn.Conv2d(160, 128, 3, stride=1, padding=1)
250
+ self.decoder3 = nn.Conv2d(128, 32, 3, stride=1, padding=1)
251
+ self.decoder4 = nn.Conv2d(32, 16, 3, stride=1, padding=1)
252
+ self.decoder5 = nn.Conv2d(16, 16, 3, stride=1, padding=1)
253
+
254
+ self.dbn1 = nn.BatchNorm2d(160)
255
+ self.dbn2 = nn.BatchNorm2d(128)
256
+ self.dbn3 = nn.BatchNorm2d(32)
257
+ self.dbn4 = nn.BatchNorm2d(16)
258
+
259
+ self.final = nn.Conv2d(16, num_classes, kernel_size=1)
260
+
261
+ self.soft = nn.Softmax(dim =1)
262
+
263
+ def forward(self, x, text_dummy):
264
+
265
+ B = x.shape[0]
266
+ ### Encoder
267
+ ### Conv Stage
268
+
269
+ ### Stage 1
270
+ out = F.relu(F.max_pool2d(self.ebn1(self.encoder1(x)),2,2))
271
+ t1 = out
272
+ ### Stage 2
273
+ out = F.relu(F.max_pool2d(self.ebn2(self.encoder2(out)),2,2))
274
+ t2 = out
275
+ ### Stage 3
276
+ out = F.relu(F.max_pool2d(self.ebn3(self.encoder3(out)),2,2))
277
+ t3 = out
278
+
279
+ ### Tokenized MLP Stage
280
+ ### Stage 4
281
+
282
+ out,H,W = self.patch_embed3(out)
283
+ for i, blk in enumerate(self.block1):
284
+ out = blk(out, H, W)
285
+ out = self.norm3(out)
286
+ out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
287
+ t4 = out
288
+
289
+ ### Bottleneck
290
+
291
+ out ,H,W= self.patch_embed4(out)
292
+ for i, blk in enumerate(self.block2):
293
+ out = blk(out, H, W)
294
+ out = self.norm4(out)
295
+ out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
296
+
297
+ ### Stage 4
298
+
299
+ out = F.relu(F.interpolate(self.dbn1(self.decoder1(out)),scale_factor=(2,2),mode ='bilinear'))
300
+
301
+ out = torch.add(out,t4)
302
+ _,_,H,W = out.shape
303
+ out = out.flatten(2).transpose(1,2)
304
+ for i, blk in enumerate(self.dblock1):
305
+ out = blk(out, H, W)
306
+
307
+ ### Stage 3
308
+
309
+ out = self.dnorm3(out)
310
+ out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
311
+ out = F.relu(F.interpolate(self.dbn2(self.decoder2(out)),scale_factor=(2,2),mode ='bilinear'))
312
+ out = torch.add(out,t3)
313
+ _,_,H,W = out.shape
314
+ out = out.flatten(2).transpose(1,2)
315
+
316
+ for i, blk in enumerate(self.dblock2):
317
+ out = blk(out, H, W)
318
+
319
+ out = self.dnorm4(out)
320
+ out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
321
+
322
+ out = F.relu(F.interpolate(self.dbn3(self.decoder3(out)),scale_factor=(2,2),mode ='bilinear'))
323
+ out = torch.add(out,t2)
324
+ out = F.relu(F.interpolate(self.dbn4(self.decoder4(out)),scale_factor=(2,2),mode ='bilinear'))
325
+ out = torch.add(out,t1)
326
+ out = F.relu(F.interpolate(self.decoder5(out),scale_factor=(2,2),mode ='bilinear'))
327
+
328
+ return self.soft(self.final(out)),0
329
+
330
+
331
+ class UNext_S(nn.Module):
332
+
333
+ ## Conv 3 + MLP 2 + shifted MLP w less parameters
334
+
335
+ def __init__(self, num_classes, input_channels=3, deep_supervision=False,img_size=256, patch_size=16, in_chans=3, embed_dims=[32, 64, 128, 512],
336
+ num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
337
+ attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
338
+ depths=[1, 1, 1], sr_ratios=[8, 4, 2, 1], **kwargs):
339
+ super().__init__()
340
+
341
+ self.encoder1 = nn.Conv2d(3, 8, 3, stride=1, padding=1)
342
+ self.encoder2 = nn.Conv2d(8, 16, 3, stride=1, padding=1)
343
+ self.encoder3 = nn.Conv2d(16, 32, 3, stride=1, padding=1)
344
+
345
+ self.ebn1 = nn.BatchNorm2d(8)
346
+ self.ebn2 = nn.BatchNorm2d(16)
347
+ self.ebn3 = nn.BatchNorm2d(32)
348
+
349
+ self.norm3 = norm_layer(embed_dims[1])
350
+ self.norm4 = norm_layer(embed_dims[2])
351
+
352
+ self.dnorm3 = norm_layer(64)
353
+ self.dnorm4 = norm_layer(32)
354
+
355
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
356
+
357
+ self.block1 = nn.ModuleList([shiftedBlock(
358
+ dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
359
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
360
+ sr_ratio=sr_ratios[0])])
361
+
362
+ self.block2 = nn.ModuleList([shiftedBlock(
363
+ dim=embed_dims[2], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
364
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
365
+ sr_ratio=sr_ratios[0])])
366
+
367
+ self.dblock1 = nn.ModuleList([shiftedBlock(
368
+ dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
369
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
370
+ sr_ratio=sr_ratios[0])])
371
+
372
+ self.dblock2 = nn.ModuleList([shiftedBlock(
373
+ dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
374
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
375
+ sr_ratio=sr_ratios[0])])
376
+
377
+ self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
378
+ embed_dim=embed_dims[1])
379
+ self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
380
+ embed_dim=embed_dims[2])
381
+
382
+ self.decoder1 = nn.Conv2d(128, 64, 3, stride=1,padding=1)
383
+ self.decoder2 = nn.Conv2d(64, 32, 3, stride=1, padding=1)
384
+ self.decoder3 = nn.Conv2d(32, 16, 3, stride=1, padding=1)
385
+ self.decoder4 = nn.Conv2d(16, 8, 3, stride=1, padding=1)
386
+ self.decoder5 = nn.Conv2d(8, 8, 3, stride=1, padding=1)
387
+
388
+ self.dbn1 = nn.BatchNorm2d(64)
389
+ self.dbn2 = nn.BatchNorm2d(32)
390
+ self.dbn3 = nn.BatchNorm2d(16)
391
+ self.dbn4 = nn.BatchNorm2d(8)
392
+
393
+ self.final = nn.Conv2d(8, num_classes, kernel_size=1)
394
+
395
+ self.soft = nn.Softmax(dim =1)
396
+
397
+ def forward(self, x, text_dummy):
398
+
399
+ B = x.shape[0]
400
+ ### Encoder
401
+ ### Conv Stage
402
+
403
+ ### Stage 1
404
+ out = F.relu(F.max_pool2d(self.ebn1(self.encoder1(x)),2,2))
405
+ t1 = out
406
+ ### Stage 2
407
+ out = F.relu(F.max_pool2d(self.ebn2(self.encoder2(out)),2,2))
408
+ t2 = out
409
+ ### Stage 3
410
+ out = F.relu(F.max_pool2d(self.ebn3(self.encoder3(out)),2,2))
411
+ t3 = out
412
+
413
+ ### Tokenized MLP Stage
414
+ ### Stage 4
415
+
416
+ out,H,W = self.patch_embed3(out)
417
+ for i, blk in enumerate(self.block1):
418
+ out = blk(out, H, W)
419
+ out = self.norm3(out)
420
+ out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
421
+ t4 = out
422
+
423
+ ### Bottleneck
424
+
425
+ out ,H,W= self.patch_embed4(out)
426
+ for i, blk in enumerate(self.block2):
427
+ out = blk(out, H, W)
428
+ out = self.norm4(out)
429
+ out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
430
+
431
+ ### Stage 4
432
+
433
+ out = F.relu(F.interpolate(self.dbn1(self.decoder1(out)),scale_factor=(2,2),mode ='bilinear'))
434
+
435
+ out = torch.add(out,t4)
436
+ _,_,H,W = out.shape
437
+ out = out.flatten(2).transpose(1,2)
438
+ for i, blk in enumerate(self.dblock1):
439
+ out = blk(out, H, W)
440
+
441
+ ### Stage 3
442
+
443
+ out = self.dnorm3(out)
444
+ out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
445
+ out = F.relu(F.interpolate(self.dbn2(self.decoder2(out)),scale_factor=(2,2),mode ='bilinear'))
446
+ out = torch.add(out,t3)
447
+ _,_,H,W = out.shape
448
+ out = out.flatten(2).transpose(1,2)
449
+
450
+ for i, blk in enumerate(self.dblock2):
451
+ out = blk(out, H, W)
452
+
453
+ out = self.dnorm4(out)
454
+ out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
455
+
456
+ out = F.relu(F.interpolate(self.dbn3(self.decoder3(out)),scale_factor=(2,2),mode ='bilinear'))
457
+ out = torch.add(out,t2)
458
+ out = F.relu(F.interpolate(self.dbn4(self.decoder4(out)),scale_factor=(2,2),mode ='bilinear'))
459
+ out = torch.add(out,t1)
460
+ out = F.relu(F.interpolate(self.decoder5(out),scale_factor=(2,2),mode ='bilinear'))
461
+
462
+ return self.final(out)
463
+
464
+
465
+class medt_net(nn.Module):
+
+    def __init__(self, block, block_2, layers, num_classes=2, zero_init_residual=True,
+                 groups=8, width_per_group=64, replace_stride_with_dilation=None,
+                 norm_layer=None, s=0.125, img_size=128, imgchan=3):
+        super(medt_net, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        self._norm_layer = norm_layer
+
+        self.inplanes = int(64 * s)
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+
+        # Global branch: shallow conv stem followed by two stages built from `block`
+        self.conv1 = nn.Conv2d(imgchan, self.inplanes, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.conv2 = nn.Conv2d(self.inplanes, 128, kernel_size=3, stride=1, padding=1, bias=False)
+        self.conv3 = nn.Conv2d(128, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = norm_layer(self.inplanes)
+        self.bn2 = norm_layer(128)
+        self.bn3 = norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.layer1 = self._make_layer(block, int(128 * s), layers[0], kernel_size=(img_size // 2))
+        self.layer2 = self._make_layer(block, int(256 * s), layers[1], stride=2, kernel_size=(img_size // 2),
+                                       dilate=replace_stride_with_dilation[0])
+
+        self.decoder4 = nn.Conv2d(int(512 * s), int(256 * s), kernel_size=3, stride=1, padding=1)
+        self.decoder5 = nn.Conv2d(int(256 * s), int(128 * s), kernel_size=3, stride=1, padding=1)
+        self.adjust = nn.Conv2d(int(128 * s), num_classes, kernel_size=1, stride=1, padding=0)
+        self.soft = nn.Softmax(dim=1)
+
+        # Local (patch-wise) branch: deeper stack built from `block_2`, applied to 32x32 crops
+        self.conv1_p = nn.Conv2d(imgchan, self.inplanes, kernel_size=7, stride=2, padding=3,
+                                 bias=False)
+        self.conv2_p = nn.Conv2d(self.inplanes, 128, kernel_size=3, stride=1, padding=1,
+                                 bias=False)
+        self.conv3_p = nn.Conv2d(128, self.inplanes, kernel_size=3, stride=1, padding=1,
+                                 bias=False)
+        self.bn1_p = norm_layer(self.inplanes)
+        self.bn2_p = norm_layer(128)
+        self.bn3_p = norm_layer(self.inplanes)
+
+        self.relu_p = nn.ReLU(inplace=True)
+
+        img_size_p = img_size // 4
+
+        self.layer1_p = self._make_layer(block_2, int(128 * s), layers[0], kernel_size=(img_size_p // 2))
+        self.layer2_p = self._make_layer(block_2, int(256 * s), layers[1], stride=2, kernel_size=(img_size_p // 2),
+                                         dilate=replace_stride_with_dilation[0])
+        self.layer3_p = self._make_layer(block_2, int(512 * s), layers[2], stride=2, kernel_size=(img_size_p // 4),
+                                         dilate=replace_stride_with_dilation[1])
+        self.layer4_p = self._make_layer(block_2, int(1024 * s), layers[3], stride=2, kernel_size=(img_size_p // 8),
+                                         dilate=replace_stride_with_dilation[2])
+
+        # Decoder
+        self.decoder1_p = nn.Conv2d(int(1024 * 2 * s), int(1024 * 2 * s), kernel_size=3, stride=2, padding=1)
+        self.decoder2_p = nn.Conv2d(int(1024 * 2 * s), int(1024 * s), kernel_size=3, stride=1, padding=1)
+        self.decoder3_p = nn.Conv2d(int(1024 * s), int(512 * s), kernel_size=3, stride=1, padding=1)
+        self.decoder4_p = nn.Conv2d(int(512 * s), int(256 * s), kernel_size=3, stride=1, padding=1)
+        self.decoder5_p = nn.Conv2d(int(256 * s), int(128 * s), kernel_size=3, stride=1, padding=1)
+
+        self.decoderf = nn.Conv2d(int(128 * s), int(128 * s), kernel_size=3, stride=1, padding=1)
+        self.adjust_p = nn.Conv2d(int(128 * s), num_classes, kernel_size=1, stride=1, padding=0)
+        self.soft_p = nn.Softmax(dim=1)
+
+    def _make_layer(self, block, planes, blocks, kernel_size=56, stride=1, dilate=False):
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, groups=self.groups,
+                            base_width=self.base_width, dilation=previous_dilation,
+                            norm_layer=norm_layer, kernel_size=kernel_size))
+        self.inplanes = planes * block.expansion
+        if stride != 1:
+            kernel_size = kernel_size // 2
+
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, groups=self.groups,
+                                base_width=self.base_width, dilation=self.dilation,
+                                norm_layer=norm_layer, kernel_size=kernel_size))
+
+        return nn.Sequential(*layers)
+
+    def _forward_impl(self, x):
+
+        xin = x.clone()
+
+        # Global branch over the full image
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.relu(x)
+        x = self.conv3(x)
+        x = self.bn3(x)
+        x = self.relu(x)
+
+        x1 = self.layer1(x)
+        x2 = self.layer2(x1)
+
+        x = F.relu(F.interpolate(self.decoder4(x2), scale_factor=(2, 2), mode='bilinear'))
+        x = torch.add(x, x1)
+        x = F.relu(F.interpolate(self.decoder5(x), scale_factor=(2, 2), mode='bilinear'))
+
+        # end of full image training
+
+        # Local branch: process a 4x4 grid of 32x32 crops independently and write results back
+        x_loc = x.clone()
+        for i in range(0, 4):
+            for j in range(0, 4):
+
+                x_p = xin[:, :, 32 * i:32 * (i + 1), 32 * j:32 * (j + 1)]
+                # begin patch wise
+                x_p = self.conv1_p(x_p)
+                x_p = self.bn1_p(x_p)
+                x_p = self.relu(x_p)
+
+                x_p = self.conv2_p(x_p)
+                x_p = self.bn2_p(x_p)
+                x_p = self.relu(x_p)
+                x_p = self.conv3_p(x_p)
+                x_p = self.bn3_p(x_p)
+                x_p = self.relu(x_p)
+
+                x1_p = self.layer1_p(x_p)
+                x2_p = self.layer2_p(x1_p)
+                x3_p = self.layer3_p(x2_p)
+                x4_p = self.layer4_p(x3_p)
+
+                x_p = F.relu(F.interpolate(self.decoder1_p(x4_p), scale_factor=(2, 2), mode='bilinear'))
+                x_p = torch.add(x_p, x4_p)
+                x_p = F.relu(F.interpolate(self.decoder2_p(x_p), scale_factor=(2, 2), mode='bilinear'))
+                x_p = torch.add(x_p, x3_p)
+                x_p = F.relu(F.interpolate(self.decoder3_p(x_p), scale_factor=(2, 2), mode='bilinear'))
+                x_p = torch.add(x_p, x2_p)
+                x_p = F.relu(F.interpolate(self.decoder4_p(x_p), scale_factor=(2, 2), mode='bilinear'))
+                x_p = torch.add(x_p, x1_p)
+                x_p = F.relu(F.interpolate(self.decoder5_p(x_p), scale_factor=(2, 2), mode='bilinear'))
+
+                x_loc[:, :, 32 * i:32 * (i + 1), 32 * j:32 * (j + 1)] = x_p
+
+        # Fuse global and local branches, then project to class logits
+        x = torch.add(x, x_loc)
+        x = F.relu(self.decoderf(x))
+
+        x = self.adjust(F.relu(x))
+
+        return x
+
+    def forward(self, x, text_dummy):
+        return self._forward_impl(x)
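The distinctive step in medt_net._forward_impl is the local branch: the input is tiled into a fixed 4x4 grid of 32x32 crops (which implicitly assumes a 128x128 input), each crop is run through the patch-wise network, and the outputs are written back into x_loc before being added to the global-branch features. A standalone sketch of just that tiling and write-back pattern; the zero-returning patch_network stand-in and the 16-channel feature map are assumptions made only so the snippet runs on its own:

import torch

x_in = torch.randn(1, 3, 128, 128)        # 128x128 input, as the 4x4 grid of 32-px crops implies
x_global = torch.randn(1, 16, 128, 128)   # stand-in for the global-branch feature map
x_loc = x_global.clone()

def patch_network(patch):
    # Placeholder for the conv*_p / layer*_p / decoder*_p pipeline; returns a same-sized map.
    return torch.zeros(patch.shape[0], 16, patch.shape[2], patch.shape[3])

for i in range(4):
    for j in range(4):
        patch = x_in[:, :, 32 * i:32 * (i + 1), 32 * j:32 * (j + 1)]
        x_loc[:, :, 32 * i:32 * (i + 1), 32 * j:32 * (j + 1)] = patch_network(patch)

fused = torch.add(x_global, x_loc)        # global + local fusion, as in _forward_impl
print(fused.shape)                        # torch.Size([1, 16, 128, 128])

Because the grid and crop size are hard-coded, inputs whose spatial size is not 128x128 would either be partially covered or raise an indexing error; the img_size argument only affects the kernel_size passed to _make_layer, not this tiling.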
AllinonSAM/biastuning/DIAS/labels/epoch_0_batch_0_img_0.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_0_batch_0_img_1.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_0_batch_1_img_0.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_0_batch_1_img_1.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_100_batch_0_img_0.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_100_batch_0_img_1.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_100_batch_1_img_0.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_100_batch_1_img_1.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_10_batch_0_img_0.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_10_batch_0_img_1.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_10_batch_1_img_0.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_10_batch_1_img_1.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_110_batch_0_img_0.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_110_batch_0_img_1.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_110_batch_1_img_0.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_110_batch_1_img_1.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_120_batch_0_img_0.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_120_batch_0_img_1.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_120_batch_1_img_0.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_120_batch_1_img_1.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_130_batch_0_img_0.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_130_batch_0_img_1.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_130_batch_1_img_0.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_130_batch_1_img_1.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_140_batch_0_img_0.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_140_batch_0_img_1.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_140_batch_1_img_0.png ADDED
AllinonSAM/biastuning/DIAS/labels/epoch_140_batch_1_img_1.png ADDED