import os

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from PIL import Image
from torch import nn

from alignedthreeattn_backbone import CLIPAttnNode, DiNOv2AttnNode, MAEAttnNode


class ThreeAttnNodes(nn.Module):
    def __init__(self, align_weights):
        super().__init__()
        # Three frozen feature extractors; only align_weights is a trainable-style parameter.
        self.backbone1 = CLIPAttnNode()
        self.backbone2 = DiNOv2AttnNode()
        self.backbone3 = MAEAttnNode()
        for backbone in [self.backbone1, self.backbone2, self.backbone3]:
            backbone.requires_grad_(False)
            backbone.eval()

        def resample_position_embeddings(embeddings, h, w):
            # Resample a (1 + n_patches, dim) positional-embedding table to an
            # h x w patch grid, keeping the [CLS] entry unchanged.
            cls_embeddings = embeddings[0]
            patch_embeddings = embeddings[1:]  # (n_patches, dim)
            hw = np.sqrt(patch_embeddings.shape[0]).astype(int)
            patch_embeddings = rearrange(patch_embeddings, "(h w) c -> c h w", h=hw)
            patch_embeddings = F.interpolate(patch_embeddings.unsqueeze(0), size=(h, w), mode="nearest").squeeze(0)
            patch_embeddings = rearrange(patch_embeddings, "c h w -> (h w) c")
            embeddings = torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)
            return embeddings

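        # Note: the source grid side is inferred as sqrt(n_patches) above; both
        # tables below are resampled to a 42x42 grid, i.e. (1 + 1764, dim) entries,
        # matching the 672x672 and 588x588 inputs used in forward().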
        # Resample CLIP's positional embeddings to the 42x42 grid.
        pos_embd = self.backbone1.model.visual.positional_embedding
        pos_embd = resample_position_embeddings(pos_embd, 42, 42)
        self.backbone1.model.visual.positional_embedding = nn.Parameter(pos_embd)

        # Resample MAE's positional embeddings and update its expected image size.
        pos_embed = self.backbone3.model.pos_embed[0]
        pos_embed = resample_position_embeddings(pos_embed, 42, 42)
        self.backbone3.model.pos_embed = nn.Parameter(pos_embed.unsqueeze(0))
        self.backbone3.model.img_size = (672, 672)
        self.backbone3.model.patch_embed.img_size = (672, 672)

        # Per-layer linear maps applied to the concatenated backbone features in forward().
        self.align_weights = nn.Parameter(align_weights)

    @torch.no_grad()
    def forward(self, x):
        # CLIP and MAE take 672x672 inputs (42x42 patches of side 16).
        x = F.interpolate(x, size=(672, 672), mode="bilinear")
        feat1 = self.backbone1(x)
        feat3 = self.backbone3(x)

        # DiNOv2 takes 588x588 inputs (42x42 patches of side 14).
        x = F.interpolate(x, size=(588, 588), mode="bilinear")
        feat2 = self.backbone2(x)

        # Concatenate per-layer features and apply the per-layer alignment transform:
        # feats (b, l, p, i) x align_weights (l, o, i) -> (b, l, p, o).
        feats = torch.cat([feat1, feat2, feat3], dim=1)
        out = torch.einsum("b l p i, l o i -> b l p o", feats, self.align_weights)
        # Drop the [CLS] token and fold patch tokens back into a 42x42 spatial grid.
        out = rearrange(out[:, :, 1:], "b l (h w) o -> b l h w o", h=42, w=42)
        return out
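

# Illustrative usage sketch (not part of the original module). The einsum in
# forward() implies align_weights has shape (num_layers, out_dim, in_dim), where
# num_layers is the total number of attention layers concatenated from the three
# backbones; the concrete sizes below are placeholders, not values from this repo.
if __name__ == "__main__":
    num_layers, out_dim, in_dim = 36, 768, 768  # assumed dimensions, for illustration only
    model = ThreeAttnNodes(torch.randn(num_layers, out_dim, in_dim))
    images = torch.randn(2, 3, 224, 224)  # any input size works; forward() resizes internally
    out = model(images)
    print(out.shape)  # expected: (2, num_layers, 42, 42, out_dim)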