# %%
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

from alignedthreeattn_backbone import CLIPAttnNode, DiNOv2AttnNode, MAEAttnNode

# align_weights = torch.load("align_weights.pth")
class ThreeAttnNodes(nn.Module):
    """Run the CLIP, DINOv2, and MAE attention-node backbones on the same image
    and project each layer's features with a per-layer alignment matrix."""

    def __init__(self, align_weights):
        super().__init__()
        self.backbone1 = CLIPAttnNode()
        self.backbone2 = DiNOv2AttnNode()
        self.backbone3 = MAEAttnNode()
        # freeze the backbones; only the alignment weights are trainable parameters
        for backbone in [self.backbone1, self.backbone2, self.backbone3]:
            backbone.requires_grad_(False)
            backbone.eval()
        # def resample_position_embeddings(embeddings, h, w):
        #     cls_embeddings = embeddings[0]
        #     patch_embeddings = embeddings[1:]  # [14*14, 768]
        #     hw = np.sqrt(patch_embeddings.shape[0]).astype(int)
        #     patch_embeddings = rearrange(patch_embeddings, "(h w) c -> c h w", h=hw)
        #     patch_embeddings = F.interpolate(patch_embeddings.unsqueeze(0), size=(h, w), mode="nearest").squeeze(0)
        #     patch_embeddings = rearrange(patch_embeddings, "c h w -> (h w) c")
        #     embeddings = torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)
        #     return embeddings

        # pos_embd = self.backbone1.model.visual.positional_embedding 
        # pos_embd = resample_position_embeddings(pos_embd, 42, 42)
        # self.backbone1.model.visual.positional_embedding = nn.Parameter(pos_embd)

        # pos_embed = self.backbone3.model.pos_embed[0]
        # pos_embed = resample_position_embeddings(pos_embed, 42, 42)
        # self.backbone3.model.pos_embed = nn.Parameter(pos_embed.unsqueeze(0))
        # self.backbone3.model.img_size = (672, 672)
        # self.backbone3.model.patch_embed.img_size = (672, 672)
            
            
        # align_weights: [36, out_dim, in_dim], one linear projection per backbone layer
        self.align_weights = nn.Parameter(align_weights)
            
    @torch.no_grad()
    def forward(self, x):
        # resize x to 224x224 for the CLIP and MAE backbones
        # x = F.interpolate(x, size=(672, 672), mode="bilinear")
        x = F.interpolate(x, size=(224, 224), mode="bilinear")
        feat1 = self.backbone1(x)
        feat3 = self.backbone3(x)
        # resize x to 196x196 for the DINOv2 backbone
        # x = F.interpolate(x, size=(588, 588), mode="bilinear")
        x = F.interpolate(x, size=(196, 196), mode="bilinear")
        feat2 = self.backbone2(x)
        # concatenate the per-layer token features of all three backbones along the layer axis: [b, 36, p, i]
        feats = torch.cat([feat1, feat2, feat3], dim=1)
        # project each layer with its own alignment matrix
        # out = torch.einsum("b l p i, l o i -> b l p o", feats, self.align_weights)
        outs = []
        for i_layer in range(36):
            out = torch.einsum("b p i, o i -> b p o", feats[:, i_layer], self.align_weights[i_layer])
            outs.append(out)
        out = torch.stack(outs, dim=1)
        # drop the class token and fold the patch tokens back into a square grid
        hw = np.sqrt(out.shape[2] - 1).astype(int)
        out = rearrange(out[:, :, 1:], "b l (h w) o -> b l h w o", h=hw, w=hw)
        return out
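
# %%
# Minimal usage sketch (not part of the original module). The real alignment
# weights would come from "align_weights.pth" (see the commented-out torch.load
# above); here a random stand-in of the assumed shape [36, out_dim, in_dim] is
# used, and the 768/768 dimensions, batch size, and image size are illustrative.
if __name__ == "__main__":
    # align_weights = torch.load("align_weights.pth")
    align_weights = torch.randn(36, 768, 768)  # assumed shape; replace with the saved weights
    model = ThreeAttnNodes(align_weights)
    images = torch.randn(2, 3, 256, 256)  # any spatial size; forward() resizes internally
    features = model(images)  # forward() already runs under torch.no_grad()
    # features: [batch, 36 layers, h, w, out_dim] spatial feature maps, class token removed
    print(features.shape)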