huzey commited on
Commit
ecba583
·
1 Parent(s): 3b4de15
Files changed (1) hide show
  1. alignedthreeattn_model.py +6 -1
alignedthreeattn_model.py CHANGED
@@ -54,7 +54,12 @@ class ThreeAttnNodes(nn.Module):
54
  x = F.interpolate(x, size=(588, 588), mode="bilinear")
55
  feat2 = self.backbone2(x)
56
  feats = torch.cat([feat1, feat2, feat3], dim=1)
57
- out = torch.einsum("b l p i, l o i -> b l p o", feats, self.align_weights)
 
 
 
 
 
58
  out = rearrange(out[:, :, 1:], "b l (h w) o -> b l h w o", h=42, w=42)
59
  return out
60
 
 
54
  x = F.interpolate(x, size=(588, 588), mode="bilinear")
55
  feat2 = self.backbone2(x)
56
  feats = torch.cat([feat1, feat2, feat3], dim=1)
57
+ # out = torch.einsum("b l p i, l o i -> b l p o", feats, self.align_weights)
58
+ outs = []
59
+ for i_layer in range(36):
60
+ out = torch.einsum("b p i, o i -> b p o", feats[:, i_layer], self.align_weights[i_layer])
61
+ outs.append(out)
62
+ out = torch.stack(outs, dim=1)
63
  out = rearrange(out[:, :, 1:], "b l (h w) o -> b l h w o", h=42, w=42)
64
  return out
65