huzey committed
Commit fd3784d · Parent: ecba583
Files changed (1):
  1. alignedthreeattn_model.py (+23 -20)
alignedthreeattn_model.py CHANGED
@@ -21,25 +21,25 @@ class ThreeAttnNodes(nn.Module):
         backbone.requires_grad_(False)
         backbone.eval()
 
-        def resample_position_embeddings(embeddings, h, w):
-            cls_embeddings = embeddings[0]
-            patch_embeddings = embeddings[1:]  # [14*14, 768]
-            hw = np.sqrt(patch_embeddings.shape[0]).astype(int)
-            patch_embeddings = rearrange(patch_embeddings, "(h w) c -> c h w", h=hw)
-            patch_embeddings = F.interpolate(patch_embeddings.unsqueeze(0), size=(h, w), mode="nearest").squeeze(0)
-            patch_embeddings = rearrange(patch_embeddings, "c h w -> (h w) c")
-            embeddings = torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)
-            return embeddings
+        # def resample_position_embeddings(embeddings, h, w):
+        #     cls_embeddings = embeddings[0]
+        #     patch_embeddings = embeddings[1:]  # [14*14, 768]
+        #     hw = np.sqrt(patch_embeddings.shape[0]).astype(int)
+        #     patch_embeddings = rearrange(patch_embeddings, "(h w) c -> c h w", h=hw)
+        #     patch_embeddings = F.interpolate(patch_embeddings.unsqueeze(0), size=(h, w), mode="nearest").squeeze(0)
+        #     patch_embeddings = rearrange(patch_embeddings, "c h w -> (h w) c")
+        #     embeddings = torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)
+        #     return embeddings
 
-        pos_embd = self.backbone1.model.visual.positional_embedding
-        pos_embd = resample_position_embeddings(pos_embd, 42, 42)
-        self.backbone1.model.visual.positional_embedding = nn.Parameter(pos_embd)
+        # pos_embd = self.backbone1.model.visual.positional_embedding
+        # pos_embd = resample_position_embeddings(pos_embd, 42, 42)
+        # self.backbone1.model.visual.positional_embedding = nn.Parameter(pos_embd)
 
-        pos_embed = self.backbone3.model.pos_embed[0]
-        pos_embed = resample_position_embeddings(pos_embed, 42, 42)
-        self.backbone3.model.pos_embed = nn.Parameter(pos_embed.unsqueeze(0))
-        self.backbone3.model.img_size = (672, 672)
-        self.backbone3.model.patch_embed.img_size = (672, 672)
+        # pos_embed = self.backbone3.model.pos_embed[0]
+        # pos_embed = resample_position_embeddings(pos_embed, 42, 42)
+        # self.backbone3.model.pos_embed = nn.Parameter(pos_embed.unsqueeze(0))
+        # self.backbone3.model.img_size = (672, 672)
+        # self.backbone3.model.patch_embed.img_size = (672, 672)
 
 
         self.align_weights = nn.Parameter(align_weights)
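The first hunk disables the positional-embedding resampling that let backbone1 (via model.visual.positional_embedding) and backbone3 (via model.pos_embed) run on a 42x42 patch grid; at the new 224x224 input the pretrained 14x14 grid presumably works as-is. For reference, a standalone, runnable sketch of the disabled helper (the function body is copied from the diff; the random tensor and print are illustrative only):

    import numpy as np
    import torch
    import torch.nn.functional as F
    from einops import rearrange

    def resample_position_embeddings(embeddings, h, w):
        # Keep the CLS position untouched; interpolate only the patch grid.
        cls_embeddings = embeddings[0]
        patch_embeddings = embeddings[1:]  # [N, C], N a perfect square (e.g. 14*14)
        hw = np.sqrt(patch_embeddings.shape[0]).astype(int)
        patch_embeddings = rearrange(patch_embeddings, "(h w) c -> c h w", h=hw)
        patch_embeddings = F.interpolate(patch_embeddings.unsqueeze(0), size=(h, w), mode="nearest").squeeze(0)
        patch_embeddings = rearrange(patch_embeddings, "c h w -> (h w) c")
        return torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)

    pos = torch.randn(1 + 14 * 14, 768)                      # CLS + 14x14 grid
    print(resample_position_embeddings(pos, 42, 42).shape)   # torch.Size([1765, 768])

With mode="nearest" and 42 = 3 x 14, the upsampling is an exact 3x replication of each position vector.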
@@ -47,11 +47,13 @@ class ThreeAttnNodes(nn.Module):
     @torch.no_grad()
     def forward(self, x):
         # resize x to 672x672
-        x = F.interpolate(x, size=(672, 672), mode="bilinear")
+        # x = F.interpolate(x, size=(672, 672), mode="bilinear")
+        x = F.interpolate(x, size=(224, 224), mode="bilinear")
         feat1 = self.backbone1(x)
         feat3 = self.backbone3(x)
         # resize x to 588x588
-        x = F.interpolate(x, size=(588, 588), mode="bilinear")
+        # x = F.interpolate(x, size=(588, 588), mode="bilinear")
+        x = F.interpolate(x, size=(196, 196), mode="bilinear")
         feat2 = self.backbone2(x)
         feats = torch.cat([feat1, feat2, feat3], dim=1)
         # out = torch.einsum("b l p i, l o i -> b l p o", feats, self.align_weights)
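The resized inputs appear chosen so that all three backbones emit the same patch grid and their features stay concatenable (the inline "resize x to 672x672" and "resize x to 588x588" comments are now stale; the code resizes to 224 and 196). A hedged sanity check, assuming 16-pixel patches for backbone1/backbone3 and 14-pixel patches for backbone2; those values are inferred from the paired sizes, not stated in the diff:

    # Hypothetical patch sizes inferred from the paired resolutions in the diff.
    for size, patch in [(672, 16), (588, 14), (224, 16), (196, 14)]:
        side = size // patch
        print(f"{size}px / {patch}px patches -> {side}x{side} = {side * side} tokens")
    # Old pair: both give 42x42 = 1764 patch tokens; new pair: both give 14x14 = 196.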
@@ -60,6 +62,7 @@ class ThreeAttnNodes(nn.Module):
             out = torch.einsum("b p i, o i -> b p o", feats[:, i_layer], self.align_weights[i_layer])
             outs.append(out)
         out = torch.stack(outs, dim=1)
-        out = rearrange(out[:, :, 1:], "b l (h w) o -> b l h w o", h=42, w=42)
+        hw = np.sqrt(out.shape[2]-1).astype(int)
+        out = rearrange(out[:, :, 1:], "b l (h w) o -> b l h w o", h=hw, w=hw)
         return out
68