# %%
import os
import torch
from PIL import Image
from einops import rearrange, repeat
import numpy as np
import torch.nn.functional as F
# align_weights = torch.load("align_weights.pth")
from torch import nn
from alignedthreeattn_backbone import CLIPAttnNode, DiNOv2AttnNode, MAEAttnNode
class ThreeAttnNodes(nn.Module):
    def __init__(self, align_weights):
        super().__init__()
        # three frozen backbones (CLIP, DiNOv2, MAE attention nodes)
        self.backbone1 = CLIPAttnNode()
        self.backbone2 = DiNOv2AttnNode()
        self.backbone3 = MAEAttnNode()
        for backbone in [self.backbone1, self.backbone2, self.backbone3]:
            backbone.requires_grad_(False)
            backbone.eval()

        # def resample_position_embeddings(embeddings, h, w):
        #     cls_embeddings = embeddings[0]
        #     patch_embeddings = embeddings[1:]  # [14*14, 768]
        #     hw = np.sqrt(patch_embeddings.shape[0]).astype(int)
        #     patch_embeddings = rearrange(patch_embeddings, "(h w) c -> c h w", h=hw)
        #     patch_embeddings = F.interpolate(patch_embeddings.unsqueeze(0), size=(h, w), mode="nearest").squeeze(0)
        #     patch_embeddings = rearrange(patch_embeddings, "c h w -> (h w) c")
        #     embeddings = torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)
        #     return embeddings

        # pos_embd = self.backbone1.model.visual.positional_embedding
        # pos_embd = resample_position_embeddings(pos_embd, 42, 42)
        # self.backbone1.model.visual.positional_embedding = nn.Parameter(pos_embd)

        # pos_embed = self.backbone3.model.pos_embed[0]
        # pos_embed = resample_position_embeddings(pos_embed, 42, 42)
        # self.backbone3.model.pos_embed = nn.Parameter(pos_embed.unsqueeze(0))
        # self.backbone3.model.img_size = (672, 672)
        # self.backbone3.model.patch_embed.img_size = (672, 672)

        # per-layer linear alignment weights: one (out_dim, in_dim) matrix per layer
        self.align_weights = nn.Parameter(align_weights)

    @torch.no_grad()
    def forward(self, x):
        # resize x to 224x224 for the CLIP and MAE backbones
        # x = F.interpolate(x, size=(672, 672), mode="bilinear")
        x = F.interpolate(x, size=(224, 224), mode="bilinear")
        feat1 = self.backbone1(x)
        feat3 = self.backbone3(x)
        # resize x to 196x196 for the DiNOv2 backbone
        # x = F.interpolate(x, size=(588, 588), mode="bilinear")
        x = F.interpolate(x, size=(196, 196), mode="bilinear")
        feat2 = self.backbone2(x)
        # concatenate per-layer features from the three backbones along the layer dim
        feats = torch.cat([feat1, feat2, feat3], dim=1)
        # out = torch.einsum("b l p i, l o i -> b l p o", feats, self.align_weights)
        # apply the per-layer alignment projection to each of the 36 layers
        outs = []
        for i_layer in range(36):
            out = torch.einsum("b p i, o i -> b p o", feats[:, i_layer], self.align_weights[i_layer])
            outs.append(out)
        out = torch.stack(outs, dim=1)
        # drop the class token and reshape patch tokens back to a square grid
        hw = np.sqrt(out.shape[2] - 1).astype(int)
        out = rearrange(out[:, :, 1:], "b l (h w) o -> b l h w o", h=hw, w=hw)
        return out
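
# %%
# Usage sketch (illustrative, not from the original file): align_weights is
# assumed to be a (36, out_dim, in_dim) tensor, matching the per-layer einsum
# "b p i, o i -> b p o" above; the file name "align_weights.pth" follows the
# commented-out load near the top of this file.
if __name__ == "__main__":
    align_weights = torch.load("align_weights.pth")  # assumed shape: (36, out_dim, in_dim)
    model = ThreeAttnNodes(align_weights).eval()
    images = torch.randn(2, 3, 256, 256)  # any spatial size; forward() resizes internally
    out = model(images)
    print(out.shape)  # expected: (batch, 36, h, w, out_dim)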