huzey committed on
Commit
6daedba
·
1 Parent(s): a47351a
Files changed (1) hide show
  1. alignedthreeattn_model.py +22 -0
alignedthreeattn_model.py CHANGED
@@ -20,6 +20,28 @@ class ThreeAttnNodes(nn.Module):
20
  for backbone in [self.backbone1, self.backbone2, self.backbone3]:
21
  backbone.requires_grad_(False)
22
  backbone.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  self.align_weights = align_weights
24
 
25
  @torch.no_grad()
 
20
  for backbone in [self.backbone1, self.backbone2, self.backbone3]:
21
  backbone.requires_grad_(False)
22
  backbone.eval()
23
+
24
def resample_position_embeddings(embeddings, h, w):
    """Resize a positional-embedding table to an ``h`` x ``w`` patch grid.

    Row 0 of ``embeddings`` is passed through unchanged (it is treated as the
    class-token embedding). The remaining rows are interpreted as a square,
    row-major grid of patch embeddings, resampled with nearest-neighbour
    interpolation, and flattened back to one row per patch.

    Args:
        embeddings: 2-D tensor of shape ``[1 + side * side, C]``,
            e.g. ``[1 + 14 * 14, 768]``.
        h: Target grid height, in patches.
        w: Target grid width, in patches.

    Returns:
        Tensor of shape ``[1 + h * w, C]``: the untouched first row followed
        by the resampled patch embeddings.

    Raises:
        ValueError: If the number of patch rows is not a positive perfect
            square, so no square source grid can be formed.
    """
    cls_embedding = embeddings[:1]  # [1, C]
    patch_embeddings = embeddings[1:]  # [side * side, C]
    num_patches, channels = patch_embeddings.shape

    # Derive the side length as a plain Python int and verify it exactly,
    # instead of trusting a truncated floating-point square root.
    side = int(round(num_patches ** 0.5))
    if side == 0 or side * side != num_patches:
        raise ValueError(
            "expected a square grid of patch embeddings, got "
            f"{num_patches} patch rows"
        )

    # "(h w) c -> c h w": interpolate expects channels first.
    grid = patch_embeddings.reshape(side, side, channels).permute(2, 0, 1)
    # interpolate works on batched input, so add and then drop a batch dim.
    grid = F.interpolate(grid.unsqueeze(0), size=(h, w), mode="nearest").squeeze(0)
    # "c h w -> (h w) c": back to one row per patch.
    patch_embeddings = grid.permute(1, 2, 0).reshape(h * w, channels)

    return torch.cat([cls_embedding, patch_embeddings], dim=0)
33
+
34
+ pos_embd = self.backbone1.model.visual.positional_embedding
35
+ pos_embd = resample_position_embeddings(pos_embd, 42, 42)
36
+ self.backbone1.model.visual.positional_embedding = nn.Parameter(pos_embd)
37
+
38
+ pos_embed = self.backbone3.model.pos_embed[0]
39
+ pos_embed = resample_position_embeddings(pos_embed, 42, 42)
40
+ self.backbone3.model.pos_embed = nn.Parameter(pos_embed.unsqueeze(0))
41
+ self.backbone3.model.img_size = (672, 672)
42
+ self.backbone3.model.patch_embed.img_size = (672, 672)
43
+
44
+
45
  self.align_weights = align_weights
46
 
47
  @torch.no_grad()