from __future__ import print_function import argparse import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms from torch.optim.lr_scheduler import StepLR import torchvision.models as models import timm from pprint import pprint import numpy as np from tqdm import tqdm from torch.utils.data.sampler import BatchSampler from gradient_reversal.module import GradientReversal class SwinModel(nn.Module): def __init__(self): super().__init__() self.swin_cl = timm.create_model('vit_large_patch16_224_in21k', pretrained=True, num_classes=0) self.swin_cb = self.swin_cl self.linear_cl = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)) self.linear_cb = nn.Linear(1024, 1024) def freeze_encoder(self): for param in self.swin_cl.parameters(): param.requires_grad = False for param in self.swin_cb.parameters(): param.requires_grad = False def unfreeze_encoder(self): for param in self.swin_cl.parameters(): param.requires_grad = True for param in self.swin_cb.parameters(): param.requires_grad = True def get_embeddings(self, image, ftype): linear = self.linear_cl if ftype == "contactless" else self.linear_cl swin = self.swin_cl if ftype == "contactless" else self.swin_cb tokens = swin(image) emb_mean = tokens.mean(dim=1) feat = linear(emb_mean) tokens_transformed = linear(tokens) return feat, tokens def forward(self, x_cl, x_cb): x_cl_tokens = self.swin_cl(x_cl) x_cb_tokens = self.swin_cb(x_cb) x_cl_mean = x_cl_tokens.mean(dim=1) x_cb_mean = x_cb_tokens.mean(dim=1) x_cl = self.linear_cl(x_cl_mean) x_cl_tokens_transformed = self.linear_cl(x_cl_tokens) x_cb = self.linear_cl(x_cb_mean) x_cb_tokens_transformed = self.linear_cl(x_cb_tokens) return x_cl, x_cb, x_cl_tokens, x_cb_tokens class SwinModel_domain_agnostic(nn.Module): def __init__(self): super().__init__() self.swin_cl = timm.create_model('vit_large_patch16_224_in21k', pretrained=True, num_classes=0) self.swin_cb = self.swin_cl #timm.create_model('vit_large_patch16_224_in21k', pretrained=True, num_classes=0) self.linear_cl = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)) self.linear_cb = nn.Linear(1024, 1024) self.classify = nn.Sequential(GradientReversal(alpha=0.6), # original 0.8 nn.Linear(1024,512), nn.ReLU(), nn.Linear(512,8)) def freeze_encoder(self): for param in self.swin_cl.parameters(): param.requires_grad = False for param in self.swin_cb.parameters(): param.requires_grad = False def unfreeze_encoder(self): for param in self.swin_cl.parameters(): param.requires_grad = True for param in self.swin_cb.parameters(): param.requires_grad = True def get_embeddings(self, image, ftype): linear = self.linear_cl if ftype == "contactless" else self.linear_cl swin = self.swin_cl if ftype == "contactless" else self.swin_cb tokens = swin(image) emb_mean = tokens.mean(dim=1) feat = linear(emb_mean) tokens_transformed = linear(tokens) return feat, tokens def forward(self, x_cl, x_cb): x_cl_tokens = self.swin_cl(x_cl) x_cb_tokens = self.swin_cb(x_cb) x_cl_mean = x_cl_tokens.mean(dim=1) x_cb_mean = x_cb_tokens.mean(dim=1) x_cl = self.linear_cl(x_cl_mean) x_cl_tokens_transformed = self.linear_cl(x_cl_tokens) x_cb = self.linear_cl(x_cb_mean) x_cb_tokens_transformed = self.linear_cl(x_cb_tokens) domain_class_cl = self.classify(x_cl_mean) domain_class_cb = self.classify(x_cb_mean) return x_cl, x_cb, x_cl_tokens, x_cb_tokens, domain_class_cl, domain_class_cb class SwinModel_Fusion(nn.Module): def __init__(self): super().__init__() self.feature_dim = 1024 self.swin_cl = timm.create_model('vit_large_patch16_224_in21k', pretrained=True, num_classes=0) self.encoder_layer = nn.TransformerEncoderLayer(d_model=self.feature_dim, nhead=4, dropout=0.5, batch_first=True, norm_first=True, activation="gelu") self.fusion = nn.TransformerEncoder(self.encoder_layer, num_layers=2) self.sep_token = nn.Parameter(torch.randn(1, 1, self.feature_dim)) self.output_logit_mlp = nn.Sequential(nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(), nn.Linear(512, 1)) self.linear_cl = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)) def load_pretrained_models(self, swin_cl_path, fusion_ckpt_path): swin_cl_state_dict = torch.load(swin_cl_path) new_dict = {} for key in swin_cl_state_dict.keys(): if "swin_cl" in key: new_dict[key.replace("swin_cl.","")] = swin_cl_state_dict[key] self.swin_cl.load_state_dict(new_dict) fusion_params = torch.load(fusion_ckpt_path) new_dict = {} for key in fusion_params.keys(): if "encoder_layer" in key: new_dict[key.replace("encoder_layer.","")] = fusion_params[key] self.encoder_layer.load_state_dict(new_dict) new_dict = {} for key in fusion_params.keys(): if "fusion" in key: new_dict[key.replace("fusion.","")] = fusion_params[key] self.fusion.load_state_dict(new_dict) self.sep_token = nn.Parameter(fusion_params["sep_token"]) new_dict = {} for key in fusion_params.keys(): if "output_logit_mlp" in key: new_dict[key.replace("output_logit_mlp.","")] = fusion_params[key] self.output_logit_mlp.load_state_dict(new_dict) def l2_norm(self,input): input_size = input.shape[0] buffer = torch.pow(input, 2) normp = torch.sum(buffer, 1).add_(1e-12) norm = torch.sqrt(normp) _output = torch.div(input, norm.view(-1, 1).expand_as(input)) return _output def combine_features(self, fingerprint_1_tokens, fingerprint_2_tokens): # This function takes a pair of embeddings [B, 49, 1024], [B, 49, 1024] and returns a B logit scores [B] # fingerprint_1_tokens = self.linear_cl(fingerprint_1_tokens) # fingerprint_2_tokens = self.linear_cl(fingerprint_2_tokens) batch_size = fingerprint_1_tokens.shape[0] sep_token = self.sep_token.repeat(batch_size, 1, 1) combine_features = torch.cat((fingerprint_1_tokens, sep_token, fingerprint_2_tokens), dim=1) fused_match_representation = self.fusion(combine_features) fingerprint_1 = fused_match_representation[:,:197,:].mean(dim=1) fingerprint_2 = fused_match_representation[:,198:,:].mean(dim=1) fingerprint_1_norm = self.l2_norm(fingerprint_1) fingerprint_2_norm = self.l2_norm(fingerprint_2) similarities = torch.sum(fingerprint_1_norm * fingerprint_2_norm, axis=1) differences = fingerprint_1 - fingerprint_2 squared_differences = differences ** 2 sum_squared_differences = torch.sum(squared_differences, axis=1) distances = torch.sqrt(sum_squared_differences) return similarities, distances def get_tokens(self, image, ftype): swin = self.swin_cl tokens = swin(image) return tokens def freeze_backbone(self): for param in self.swin_cl.parameters(): param.requires_grad = False def forward(self, x_cl, x_cb): x_cl_tokens = self.swin_cl(x_cl) x_cb_tokens = self.swin_cl(x_cb) return x_cl_tokens, x_cb_tokens