upload

alignedthreeattn_model.py CHANGED (+23 -20)
@@ -21,25 +21,25 @@ class ThreeAttnNodes(nn.Module):
         backbone.requires_grad_(False)
         backbone.eval()
 
-        def resample_position_embeddings(embeddings, h, w):
-            cls_embeddings = embeddings[0]
-            patch_embeddings = embeddings[1:]  # [14*14, 768]
-            hw = np.sqrt(patch_embeddings.shape[0]).astype(int)
-            patch_embeddings = rearrange(patch_embeddings, "(h w) c -> c h w", h=hw)
-            patch_embeddings = F.interpolate(patch_embeddings.unsqueeze(0), size=(h, w), mode="nearest").squeeze(0)
-            patch_embeddings = rearrange(patch_embeddings, "c h w -> (h w) c")
-            embeddings = torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)
-            return embeddings
+        # def resample_position_embeddings(embeddings, h, w):
+        #     cls_embeddings = embeddings[0]
+        #     patch_embeddings = embeddings[1:]  # [14*14, 768]
+        #     hw = np.sqrt(patch_embeddings.shape[0]).astype(int)
+        #     patch_embeddings = rearrange(patch_embeddings, "(h w) c -> c h w", h=hw)
+        #     patch_embeddings = F.interpolate(patch_embeddings.unsqueeze(0), size=(h, w), mode="nearest").squeeze(0)
+        #     patch_embeddings = rearrange(patch_embeddings, "c h w -> (h w) c")
+        #     embeddings = torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)
+        #     return embeddings
 
-        pos_embd = self.backbone1.model.visual.positional_embedding
-        pos_embd = resample_position_embeddings(pos_embd, 42, 42)
-        self.backbone1.model.visual.positional_embedding = nn.Parameter(pos_embd)
+        # pos_embd = self.backbone1.model.visual.positional_embedding
+        # pos_embd = resample_position_embeddings(pos_embd, 42, 42)
+        # self.backbone1.model.visual.positional_embedding = nn.Parameter(pos_embd)
 
-        pos_embed = self.backbone3.model.pos_embed[0]
-        pos_embed = resample_position_embeddings(pos_embed, 42, 42)
-        self.backbone3.model.pos_embed = nn.Parameter(pos_embed.unsqueeze(0))
-        self.backbone3.model.img_size = (672, 672)
-        self.backbone3.model.patch_embed.img_size = (672, 672)
+        # pos_embed = self.backbone3.model.pos_embed[0]
+        # pos_embed = resample_position_embeddings(pos_embed, 42, 42)
+        # self.backbone3.model.pos_embed = nn.Parameter(pos_embed.unsqueeze(0))
+        # self.backbone3.model.img_size = (672, 672)
+        # self.backbone3.model.patch_embed.img_size = (672, 672)
 
 
         self.align_weights = nn.Parameter(align_weights)
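For reference, the helper this commit disables is a standard ViT position-embedding resampler: it splits off the CLS entry, folds the per-patch rows back into their native square grid, interpolates that grid to the requested size, and flattens it again. Below is a minimal standalone sketch of the same logic; the imports and the dummy usage at the bottom are added for illustration (in the file, this ran inside __init__ and used nearest-neighbor interpolation, as shown):

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange

def resample_position_embeddings(embeddings, h, w):
    # Keep the CLS row aside; only the patch grid is resampled.
    cls_embeddings = embeddings[0]
    patch_embeddings = embeddings[1:]  # e.g. [14*14, 768]
    hw = np.sqrt(patch_embeddings.shape[0]).astype(int)
    # (h w) c -> c h w: recover the square token grid as a CHW "image".
    patch_embeddings = rearrange(patch_embeddings, "(h w) c -> c h w", h=hw)
    patch_embeddings = F.interpolate(
        patch_embeddings.unsqueeze(0), size=(h, w), mode="nearest"
    ).squeeze(0)
    patch_embeddings = rearrange(patch_embeddings, "c h w -> (h w) c")
    # Re-attach the CLS row in front of the resampled patch rows.
    return torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)

# Dummy shapes for illustration: a [1 + 14*14, 768] table resampled to a
# 42x42 grid, as the removed __init__ code did for 672px inputs.
pos = torch.randn(1 + 14 * 14, 768)
print(resample_position_embeddings(pos, 42, 42).shape)  # torch.Size([1765, 768])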
@@ -47,11 +47,13 @@ class ThreeAttnNodes(nn.Module):
     @torch.no_grad()
     def forward(self, x):
         # resize x to 672x672
-        x = F.interpolate(x, size=(672, 672), mode="bilinear")
+        # x = F.interpolate(x, size=(672, 672), mode="bilinear")
+        x = F.interpolate(x, size=(224, 224), mode="bilinear")
         feat1 = self.backbone1(x)
         feat3 = self.backbone3(x)
         # resize x to 588x588
-        x = F.interpolate(x, size=(588, 588), mode="bilinear")
+        # x = F.interpolate(x, size=(588, 588), mode="bilinear")
+        x = F.interpolate(x, size=(196, 196), mode="bilinear")
         feat2 = self.backbone2(x)
         feats = torch.cat([feat1, feat2, feat3], dim=1)
         # out = torch.einsum("b l p i, l o i -> b l p o", feats, self.align_weights)
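The replacement sizes are not arbitrary: they restore each backbone's native input resolution, which is also why the position-embedding resampling above is no longer needed. Assuming patch size 16 for backbone1 and backbone3 and patch size 14 for backbone2 (an inference from the resolutions in this diff, not something it states), the grid arithmetic works out as follows:

# Patch sizes here are inferred from the numbers in the diff, not stated in the file.
assert 224 // 16 == 14 and 196 // 14 == 14  # new inputs -> native 14x14 token grids
assert 672 // 16 == 42 and 588 // 14 == 42  # old inputs -> upscaled 42x42 grids

Note that the "# resize x to 672x672" and "# resize x to 588x588" comments are left stale by this change; the lines below them now resize to 224 and 196.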
@@ -60,6 +62,7 @@ class ThreeAttnNodes(nn.Module):
             out = torch.einsum("b p i, o i -> b p o", feats[:, i_layer], self.align_weights[i_layer])
             outs.append(out)
         out = torch.stack(outs, dim=1)
-
+        hw = np.sqrt(out.shape[2]-1).astype(int)
+        out = rearrange(out[:, :, 1:], "b l (h w) o -> b l h w o", h=hw, w=hw)
         return out
 
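The two lines added at the end of forward drop the CLS token and unflatten the token axis into a square grid, so the model now returns feature maps shaped [batch, layer, h, w, dim] rather than flat token lists. A toy shape check, with made-up batch, layer, and channel sizes (the 14x14 grid matches the new 224/196 inputs):

import numpy as np
import torch
from einops import rearrange

out = torch.randn(2, 3, 1 + 14 * 14, 64)    # [batch, layer, 1 + patches, dim]
hw = np.sqrt(out.shape[2] - 1).astype(int)  # 14: side of the square patch grid
out = rearrange(out[:, :, 1:], "b l (h w) o -> b l h w o", h=hw, w=hw)
print(out.shape)  # torch.Size([2, 3, 14, 14, 64])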