# %%
import os

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from PIL import Image
from torch import nn

from backbone import CLIPAttnNode, DiNOv2AttnNode, MAEAttnNode

# Loaded once at import time; used as the default projection for every
# ThreeAttnNodes instance (all such instances share this same tensor object).
# NOTE(review): importing this module requires "align_weights.pth" to be
# resolvable from the current working directory.
align_weights = torch.load("align_weights.pth")


class ThreeAttnNodes(nn.Module):
    """Frozen ensemble of CLIP, DINOv2 and MAE attention-node extractors.

    The three backbones are run on resized copies of the input. Their outputs
    are concatenated along dim 1 and projected with per-index alignment
    weights into a common output width.

    Args:
        align_weights: Projection tensor used as the ``l o i`` operand of the
            einsum in :meth:`forward`. Its first dim must equal the total
            size of the concatenated dim 1 of the three backbone outputs.
            Defaults to the module-level tensor loaded from
            ``align_weights.pth``.
    """

    def __init__(self, align_weights=align_weights):
        super().__init__()
        self.backbone1 = CLIPAttnNode()
        self.backbone2 = DiNOv2AttnNode()
        self.backbone3 = MAEAttnNode()
        # Freeze every backbone: no gradients, and eval-mode layers.
        for backbone in (self.backbone1, self.backbone2, self.backbone3):
            backbone.requires_grad_(False)
            backbone.eval()
        # Kept as a plain attribute (not a parameter/buffer) so state_dict is
        # unaffected. Module.to()/.cuda()/.half() therefore do NOT move or cast
        # it; forward() aligns it to the features' device and dtype instead.
        self.align_weights = align_weights

    @torch.no_grad()
    def forward(self, x):
        """Extract, concatenate and align features from the three backbones.

        Args:
            x: 4-D batch accepted by bilinear ``F.interpolate``, i.e.
                ``(batch, channels, H, W)``. Presumably RGB images in the
                normalization the backbones expect; confirm against the
                ``backbone`` module.

        Returns:
            Tensor of shape ``(b, l, 42, 42, o)``. ``l`` is the concatenated
            dim 1 of the backbone outputs and ``o`` comes from
            ``align_weights``.
        """
        # CLIP and MAE consume 672x672 input.
        x = F.interpolate(x, size=(672, 672), mode="bilinear")
        feat1 = self.backbone1(x)
        feat3 = self.backbone3(x)
        # DINOv2 consumes 588x588 input.
        # NOTE(review): this resamples the already-resized 672x672 tensor, not
        # the original input. Kept as-is because align_weights were presumably
        # fit on features from exactly this pipeline.
        x = F.interpolate(x, size=(588, 588), mode="bilinear")
        feat2 = self.backbone2(x)

        # Each feat must be 4-D (b, l_k, p, i) with matching b, p and i;
        # they are stacked along dim 1.
        feats = torch.cat([feat1, feat2, feat3], dim=1)
        # Match the weights to the features so the einsum works after the
        # module (and hence the backbones' outputs) has been moved to another
        # device or precision. This is a no-op when they already match.
        weights = self.align_weights.to(device=feats.device, dtype=feats.dtype)
        out = torch.einsum("b l p i, l o i -> b l p o", feats, weights)
        # Drop the first token of every sequence (presumably a CLS token) and
        # fold the remaining 42*42 tokens into a spatial grid. 42 matches
        # 672/16 and 588/14, presumably the backbones' patch sizes; confirm.
        out = rearrange(out[:, :, 1:], "b l (h w) o -> b l h w o", h=42, w=42)
        return out