upload
Browse files
alignedthreeattn_model.py +22 -0
alignedthreeattn_model.py
@@ -20,6 +20,28 @@ class ThreeAttnNodes(nn.Module):
|
|
20 |
for backbone in [self.backbone1, self.backbone2, self.backbone3]:
|
21 |
backbone.requires_grad_(False)
|
22 |
backbone.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
self.align_weights = align_weights
|
24 |
|
25 |
@torch.no_grad()
|
|
|
20 |
for backbone in [self.backbone1, self.backbone2, self.backbone3]:
|
21 |
backbone.requires_grad_(False)
|
22 |
backbone.eval()
|
23 |
+
|
24 |
+
def resample_position_embeddings(embeddings, h, w):
|
25 |
+
cls_embeddings = embeddings[0]
|
26 |
+
patch_embeddings = embeddings[1:] # [14*14, 768]
|
27 |
+
hw = np.sqrt(patch_embeddings.shape[0]).astype(int)
|
28 |
+
patch_embeddings = rearrange(patch_embeddings, "(h w) c -> c h w", h=hw)
|
29 |
+
patch_embeddings = F.interpolate(patch_embeddings.unsqueeze(0), size=(h, w), mode="nearest").squeeze(0)
|
30 |
+
patch_embeddings = rearrange(patch_embeddings, "c h w -> (h w) c")
|
31 |
+
embeddings = torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)
|
32 |
+
return embeddings
|
33 |
+
|
34 |
+
pos_embd = self.backbone1.model.visual.positional_embedding
|
35 |
+
pos_embd = resample_position_embeddings(pos_embd, 42, 42)
|
36 |
+
self.backbone1.model.visual.positional_embedding = nn.Parameter(pos_embd)
|
37 |
+
|
38 |
+
pos_embed = self.backbone3.model.pos_embed[0]
|
39 |
+
pos_embed = resample_position_embeddings(pos_embed, 42, 42)
|
40 |
+
self.backbone3.model.pos_embed = nn.Parameter(pos_embed.unsqueeze(0))
|
41 |
+
self.backbone3.model.img_size = (672, 672)
|
42 |
+
self.backbone3.model.patch_embed.img_size = (672, 672)
|
43 |
+
|
44 |
+
|
45 |
self.align_weights = align_weights
|
46 |
|
47 |
@torch.no_grad()
|