upload
Browse files
alignedthreeattn_model.py
CHANGED
@@ -54,7 +54,12 @@ class ThreeAttnNodes(nn.Module):
|
|
54 |
x = F.interpolate(x, size=(588, 588), mode="bilinear")
|
55 |
feat2 = self.backbone2(x)
|
56 |
feats = torch.cat([feat1, feat2, feat3], dim=1)
|
57 |
-
out = torch.einsum("b l p i, l o i -> b l p o", feats, self.align_weights)
|
|
|
|
|
|
|
|
|
|
|
58 |
out = rearrange(out[:, :, 1:], "b l (h w) o -> b l h w o", h=42, w=42)
|
59 |
return out
|
60 |
|
|
|
54 |
x = F.interpolate(x, size=(588, 588), mode="bilinear")
|
55 |
feat2 = self.backbone2(x)
|
56 |
feats = torch.cat([feat1, feat2, feat3], dim=1)
|
57 |
+
# out = torch.einsum("b l p i, l o i -> b l p o", feats, self.align_weights)
|
58 |
+
outs = []
|
59 |
+
for i_layer in range(36):
|
60 |
+
out = torch.einsum("b p i, o i -> b p o", feats[:, i_layer], self.align_weights[i_layer])
|
61 |
+
outs.append(out)
|
62 |
+
out = torch.stack(outs, dim=1)
|
63 |
out = rearrange(out[:, :, 1:], "b l (h w) o -> b l h w o", h=42, w=42)
|
64 |
return out
|
65 |
|