spandey8 committed on
Commit
007d3b9
·
verified ·
1 Parent(s): 3cae3e0

Upload 11 files

Browse files
combined_sampler.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch.utils.data.sampler import Sampler
5
+ from tqdm import *
6
+
7
class BalancedSampler(Sampler):
    """Sampler that emits dataset indices in same-label groups.

    Every "batch" produced by ``__iter__`` consists of
    ``batch_size // images_per_class`` distinct classes, each contributing
    ``images_per_class`` indices drawn with replacement from that class.

    Args:
        data_source: Dataset-like object exposing ``all_labels`` (one label
            per sample) and supporting ``len()``.
        batch_size: Nominal batch size used to derive the number of classes
            per batch and the number of batches per pass.
        images_per_class: Number of indices drawn for each sampled class.

    Raises:
        ValueError: From ``np.random.choice`` during iteration if more
            classes per batch are requested than distinct labels exist.
    """

    def __init__(self, data_source, batch_size, images_per_class=3):
        self.data_source = data_source
        self.ys = np.array(data_source.all_labels)
        self.num_groups = batch_size // images_per_class
        self.batch_size = batch_size
        self.num_instances = images_per_class
        self.num_samples = len(self.ys)
        # Distinct label values actually present (sorted).  Sampling by
        # position in this array keeps the sampler working when labels are
        # not the contiguous integers 0..C-1.
        self.classes = np.unique(self.ys)
        self.num_classes = len(self.classes)
        # Indices of the samples of each class, computed once rather than
        # on every draw; aligned with ``self.classes``.
        self._class_indices = [np.flatnonzero(self.ys == c) for c in self.classes]

    def __len__(self):
        # NOTE(review): this reports the dataset size, while ``__iter__``
        # yields ``(len(data_source) // batch_size) * num_groups *
        # num_instances`` indices; kept as-is so ``len(DataLoader)`` does
        # not change for existing callers.
        return self.num_samples

    def __iter__(self):
        num_batches = len(self.data_source) // self.batch_size
        ret = []
        for _ in range(num_batches):
            # Same RNG call as before (positions 0..num_classes-1), but the
            # positions now index the observed classes instead of being
            # compared against raw label values.
            picked = np.random.choice(self.num_classes, self.num_groups, replace=False)
            for class_pos in picked:
                members = self._class_indices[class_pos]
                chosen = np.random.choice(members, size=self.num_instances, replace=True)
                ret.extend(np.random.permutation(chosen))
        return iter(ret)
hkpoly_evaluation_phase1.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # script to evaluate the HKPolyU testing dataset on the finetuned model after phase 1
2
+ import torch
3
+ from datasets.hkpoly_test import hktest
4
+ from utils import Prev_RetMetric, l2_norm, compute_recall_at_k
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+ from model import SwinModel_domain_agnostic as Model
8
+ from sklearn.metrics import roc_curve, auc
9
+ import json
10
+
11
def calculate_tar_at_far(fpr, tpr, target_fars):
    """Look up the true-accept rate at each requested false-accept rate.

    Args:
        fpr: NumPy array of false-positive rates (ROC x-coordinates).
        tpr: NumPy array of true-positive rates aligned with ``fpr``.
        target_fars: Iterable of FAR values to evaluate.

    Returns:
        Dict mapping each target FAR to its TAR: the first ``tpr`` entry
        whose ``fpr`` equals the target exactly, otherwise the value
        linearly interpolated with ``np.interp``.
    """
    results = {}
    for target in target_fars:
        if target in fpr:
            # Exact hit on the curve: take the first matching point.
            first_match = np.flatnonzero(fpr == target)[0]
            results[target] = tpr[first_match]
            continue
        results[target] = np.interp(target, fpr, tpr)
    return results
20
+
21
if __name__ == '__main__':
    # Evaluate the phase-1 fine-tuned model on the HKPolyU test split:
    # extract contactless (CL) and contact-based (CB) embeddings, build a
    # CL x CB score matrix, and report ROC AUC, EER, TAR@FAR and recall@k.
    device = torch.device('cuda')
    data = hktest(split='test')
    dataloader = torch.utils.data.DataLoader(data, batch_size=16, num_workers=1, pin_memory=True)
    model = Model().to(device)
    checkpoint = torch.load("ridgeformer_checkpoints/phase1_ft_hkpoly.pt", map_location=torch.device('cpu'))
    # strict=False: keys that do not match between checkpoint and model are
    # ignored without raising.
    model.load_state_dict(checkpoint, strict=False)
    model.eval()

    cl_feats, cb_feats, cl_labels, cb_labels = [], [], [], []
    cl_feats_unnormed, cb_feats_unnormed = [], []
    with torch.no_grad():
        for (x_cl, x_cb, label) in tqdm(dataloader):
            x_cl, x_cb, label = x_cl.to(device), x_cb.to(device), label.to(device)
            x_cl_feat, x_cl_token = model.get_embeddings(x_cl, 'contactless')
            x_cb_feat, x_cb_token = model.get_embeddings(x_cb, 'contactbased')
            # Raw embeddings feed the Euclidean term, L2-normalised ones the
            # dot-product term of the final score.
            cl_feats_unnormed.append(x_cl_feat.cpu().detach().numpy())
            cb_feats_unnormed.append(x_cb_feat.cpu().detach().numpy())
            x_cl_feat = l2_norm(x_cl_feat).cpu().detach().numpy()
            x_cb_feat = l2_norm(x_cb_feat).cpu().detach().numpy()
            label = label.cpu().detach().numpy()
            cl_feats.append(x_cl_feat)
            cb_feats.append(x_cb_feat)
            cl_labels.append(label)
            cb_labels.append(label)

    cl_feats = np.concatenate(cl_feats)
    cb_feats = np.concatenate(cb_feats)
    cl_feats_unnormed = np.concatenate(cl_feats_unnormed)
    cb_feats_unnormed = np.concatenate(cb_feats_unnormed)
    cl_label = torch.from_numpy(np.concatenate(cl_labels))
    cb_label = torch.from_numpy(np.concatenate(cb_labels))

    # CB2CL
    # Negative Euclidean distance between every CL/CB pair.  Computed a few
    # CL rows at a time: broadcasting the whole set at once allocates an
    # (N, N, D) temporary, which grows quadratically with the test-set size.
    row_chunk = 16
    distance_rows = []
    for start in range(0, cl_feats_unnormed.shape[0], row_chunk):
        diff = cl_feats_unnormed[start:start + row_chunk, np.newaxis] - cb_feats_unnormed
        distance_rows.append(-1 * np.sqrt(np.sum(np.square(diff), axis=2)))
    distance = np.concatenate(distance_rows)
    similarities = np.dot(cl_feats, np.transpose(cb_feats))
    scores_mat = similarities + 0.1 * distance

    scores = scores_mat.flatten().tolist()
    # A pair is genuine when the CL and CB labels are equal.
    labels = torch.eq(cl_label.view(-1, 1) - cb_label.view(1, -1), 0.0).flatten().tolist()

    fpr, tpr, thresh = roc_curve(labels, scores, drop_intermediate=True)

    def fpr_bracket(far):
        """Return (last index with fpr < far, first index with fpr >= far)."""
        return np.flatnonzero(fpr < far).max(), np.flatnonzero(fpr >= far).min()

    # NOTE(review): TAR@FAR=1e-2 uses only the upper bracket point, whereas
    # 1e-3 and 1e-4 average the two bracketing points; kept as in the
    # original so reported numbers do not change - confirm this is intended.
    lower_fpr_idx, upper_fpr_idx = fpr_bracket(0.01)
    tar_far_102 = tpr[upper_fpr_idx]
    print(tpr[lower_fpr_idx], lower_fpr_idx, fpr[lower_fpr_idx], thresh[lower_fpr_idx])
    print(tpr[upper_fpr_idx], upper_fpr_idx, fpr[upper_fpr_idx], thresh[upper_fpr_idx])

    lower_fpr_idx, upper_fpr_idx = fpr_bracket(0.001)
    tar_far_103 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2
    print(tpr[lower_fpr_idx], lower_fpr_idx, fpr[lower_fpr_idx])
    print(tpr[upper_fpr_idx], upper_fpr_idx, fpr[upper_fpr_idx])

    lower_fpr_idx, upper_fpr_idx = fpr_bracket(0.0001)
    tar_far_104 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2
    print(tpr[lower_fpr_idx], lower_fpr_idx, fpr[lower_fpr_idx])
    print(tpr[upper_fpr_idx], upper_fpr_idx, fpr[upper_fpr_idx])

    # EER: the FPR at the ROC point where FNR and FPR are closest.
    fnr = 1 - tpr
    EER = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    roc_auc = auc(fpr, tpr)
    print(f"ROCAUC for CB2CL: {roc_auc * 100} %")
    print(f"EER for CB2CL: {EER * 100} %")
    eer_cb2cl = EER * 100

    cbcltf102 = tar_far_102 * 100
    cbcltf103 = tar_far_103 * 100
    cbcltf104 = tar_far_104 * 100
    cl_label = cl_label.cpu().detach()
    cb_label = cb_label.cpu().detach()
    print(f"TAR@FAR=10^-2 for CB2CL: {tar_far_102 * 100} %")
    print(f"TAR@FAR=10^-3 for CB2CL: {tar_far_103 * 100} %")
    print(f"TAR@FAR=10^-4 for CB2CL: {tar_far_104 * 100} %")

    for k in (1, 10, 50, 100):
        recall = compute_recall_at_k(torch.from_numpy(scores_mat), cl_label, cb_label, k)
        print(f"R@{k} for CB2CL: {recall * 100} %")
hkpoly_evaluation_phase2.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # script to evaluate HKPolyU testing dataset on finetuned model after phase 2
2
+ import torch
3
+ from datasets.hkpoly_test import hktest
4
+ from utils import Prev_RetMetric, l2_norm, compute_recall_at_k
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+ from model import SwinModel_Fusion as Model
8
+ from sklearn.metrics import roc_curve, auc
9
+ import json
10
+
11
def calculate_tar_at_far(fpr, tpr, target_fars):
    """Look up the true-accept rate at each requested false-accept rate.

    Args:
        fpr: NumPy array of false-positive rates (ROC x-coordinates).
        tpr: NumPy array of true-positive rates aligned with ``fpr``.
        target_fars: Iterable of FAR values to evaluate.

    Returns:
        Dict mapping each target FAR to its TAR: the first ``tpr`` entry
        whose ``fpr`` equals the target exactly, otherwise the value
        linearly interpolated with ``np.interp``.
    """
    results = {}
    for target in target_fars:
        if target in fpr:
            # Exact hit on the curve: take the first matching point.
            first_match = np.flatnonzero(fpr == target)[0]
            results[target] = tpr[first_match]
            continue
        results[target] = np.interp(target, fpr, tpr)
    return results
20
+
21
def get_fused_cross_score_matrix(model, cl_tokens, cb_tokens, shard_size=20):
    """Score every contactless/contact-based token pair with the fusion model.

    Pairs are evaluated in ``shard_size`` x ``shard_size`` tiles so that only
    one tile of expanded pairs is passed to the model at a time.

    Args:
        model: Object exposing ``combine_features(a, b)`` that returns
            ``(similarity_scores, distances)``, one value per pair.
        cl_tokens: List of contactless token tensors, concatenated along
            dim 0; each is expected to be 3-D (samples, tokens, dim).
        cb_tokens: List of contact-based token tensors, same layout.
        shard_size: Tile edge length; defaults to the previously hard-coded 20.

    Returns:
        CPU tensor of shape (num_cl, num_cb) whose entry ``[i, j]`` is
        ``similarity - 0.1 * distance`` for contactless sample ``i`` and
        contact-based sample ``j``.
    """
    cl_tokens = torch.cat(cl_tokens)
    cb_tokens = torch.cat(cb_tokens)
    num_cl, num_cb = cl_tokens.shape[0], cb_tokens.shape[0]
    similarity_matrix = torch.zeros((num_cl, num_cb))
    # The result is detached anyway, so no autograd graph is needed.
    with torch.no_grad():
        for i_start in tqdm(range(0, num_cl, shard_size)):
            shard_i = cl_tokens[i_start:i_start + shard_size]
            rows = shard_i.shape[0]
            for j_start in range(0, num_cb, shard_size):
                shard_j = cb_tokens[j_start:j_start + shard_size]
                cols = shard_j.shape[0]
                # Expand to the ACTUAL tile size: the trailing tile is
                # smaller than shard_size whenever the sample count is not a
                # multiple of it, and expanding to a fixed shard_size then
                # mispairs rows/columns and breaks the reshape below.
                pairwise_i = shard_i.unsqueeze(1).expand(-1, cols, -1, -1)
                pairwise_j = shard_j.unsqueeze(0).expand(rows, -1, -1, -1)
                similarity_scores, distances = model.combine_features(
                    pairwise_i.reshape(-1, *shard_i.shape[1:]),
                    pairwise_j.reshape(-1, *shard_j.shape[1:]),
                )
                scores = (similarity_scores - 0.1 * distances).reshape(rows, cols)
                similarity_matrix[i_start:i_start + rows, j_start:j_start + cols] = scores.cpu().detach()
    return similarity_matrix
42
+
43
if __name__ == '__main__':
    # Evaluate the phase-2 fine-tuned fusion model on the HKPolyU test split:
    # extract contactless (CL) and contact-based (CB) tokens, score every
    # CL/CB pair with the fusion head, and report ROC AUC, EER, TAR@FAR and
    # recall@k.
    device = torch.device('cuda')
    data = hktest(split='test')
    dataloader = torch.utils.data.DataLoader(data, batch_size=16, num_workers=1, pin_memory=True)
    model = Model().to(device)
    checkpoint = torch.load("ridgeformer_checkpoints/phase2_ft_hkpoly.pt", map_location=torch.device('cpu'))
    # strict=False: keys that do not match between checkpoint and model are
    # ignored without raising.
    model.load_state_dict(checkpoint, strict=False)
    model.eval()

    # NOTE: despite the names, cl_feats / cb_feats hold the per-batch token
    # tensors returned by model.get_tokens.
    cl_feats, cb_feats, cl_labels, cb_labels = [], [], [], []
    with torch.no_grad():
        for (x_cl, x_cb, label) in tqdm(dataloader):
            x_cl, x_cb, label = x_cl.to(device), x_cb.to(device), label.to(device)
            x_cl_token = model.get_tokens(x_cl, 'contactless')
            x_cb_token = model.get_tokens(x_cb, 'contactbased')
            label = label.cpu().detach().numpy()
            cl_feats.append(x_cl_token)
            cb_feats.append(x_cb_token)
            cl_labels.append(label)
            cb_labels.append(label)

    cl_label = torch.from_numpy(np.concatenate(cl_labels))
    cb_label = torch.from_numpy(np.concatenate(cb_labels))

    # CB2CL
    # Pairwise fusion scoring is inference only; keep autograd disabled so no
    # graph is built for the many model calls it makes.
    with torch.no_grad():
        scores_mat = get_fused_cross_score_matrix(model, cl_feats, cb_feats)
    scores = scores_mat.cpu().detach().numpy().flatten().tolist()
    # A pair is genuine when the CL and CB labels are equal.
    labels = torch.eq(cl_label.view(-1, 1) - cb_label.view(1, -1), 0.0).flatten().tolist()

    fpr, tpr, thresh = roc_curve(labels, scores, drop_intermediate=True)

    def fpr_bracket(far):
        """Return (last index with fpr < far, first index with fpr >= far)."""
        return np.flatnonzero(fpr < far).max(), np.flatnonzero(fpr >= far).min()

    # NOTE(review): TAR@FAR=1e-2 uses only the upper bracket point, whereas
    # 1e-3 and 1e-4 average the two bracketing points; kept as in the
    # original so reported numbers do not change - confirm this is intended.
    lower_fpr_idx, upper_fpr_idx = fpr_bracket(0.01)
    tar_far_102 = tpr[upper_fpr_idx]

    lower_fpr_idx, upper_fpr_idx = fpr_bracket(0.001)
    tar_far_103 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2

    lower_fpr_idx, upper_fpr_idx = fpr_bracket(0.0001)
    tar_far_104 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2

    # EER: the FPR at the ROC point where FNR and FPR are closest.
    fnr = 1 - tpr
    EER = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    roc_auc = auc(fpr, tpr)
    print(f"ROCAUC for CB2CL: {roc_auc * 100} %")
    print(f"EER for CB2CL: {EER * 100} %")
    eer_cb2cl = EER * 100
    cbcltf102 = tar_far_102 * 100
    cbcltf103 = tar_far_103 * 100
    cbcltf104 = tar_far_104 * 100
    cl_label = cl_label.cpu().detach()
    cb_label = cb_label.cpu().detach()

    print(f"TAR@FAR=10^-2 for CB2CL: {tar_far_102 * 100} %")
    print(f"TAR@FAR=10^-3 for CB2CL: {tar_far_103 * 100} %")
    print(f"TAR@FAR=10^-4 for CB2CL: {tar_far_104 * 100} %")

    # Only the reported cut-offs are evaluated; the previous loop computed
    # recall for every k in 1..100 into a dict that was never read.
    for k in (1, 10, 50, 100):
        recall = compute_recall_at_k(scores_mat, cl_label, cb_label, k)
        print(f"R@{k} for CB2CL: {recall * 100} %")
loss.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pytorch_metric_learning import losses
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.init
5
+ import torchvision.models as models
6
+ from torch.autograd import Variable
7
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
8
+ from torch.nn.utils.weight_norm import weight_norm
9
+ import torch.backends.cudnn as cudnn
10
+ from torch.nn.utils.clip_grad import clip_grad_norm
11
+ import numpy as np
12
+ import os
13
+ import torch.nn.functional as F
14
+ import itertools
15
+
16
+ torch.autograd.set_detect_anomaly(True)
17
+
18
class DualMSLoss_FineGrained(nn.Module):
    """Multi-similarity style contrastive loss over two embedding sets.

    ``forward`` L2-normalises the contactless and contact-based embeddings,
    builds three cosine-similarity matrices (contactless/contactless,
    contact-based/contact-based, contact-based/contactless) and sums the
    mined loss of each matrix and of its transpose.

    NOTE: the ``margin`` constructor argument is accepted for interface
    compatibility but has never taken effect - the margin is fixed at 0.7.
    ``max_violation`` is stored and not read by any method of this class.
    """

    def __init__(self, margin=0, max_violation=False):
        super(DualMSLoss_FineGrained, self).__init__()
        self.max_violation = max_violation
        self.thresh = 0.5
        # Fixed value; the original code overwrote ``margin`` with it too.
        self.margin = 0.7  # 0.1
        self.scale_pos = 2
        self.scale_neg = 40.0

    @staticmethod
    def _same_label(label, device):
        """Boolean (N, N) matrix, True where two samples share a label."""
        return torch.eq(label.view(-1, 1) - label.view(1, -1), 0.0).to(device)

    def _mined_loss(self, sim_mat, is_pos, is_neg):
        """Mine hard pairs per row of ``sim_mat`` and return the scalar loss.

        A positive counts when its similarity minus the margin is below the
        row's largest negative; a negative counts when its similarity plus
        the margin exceeds the row's smallest positive.
        """
        pos_exp = torch.exp(-self.scale_pos * (sim_mat - self.thresh))
        neg_exp = torch.exp(self.scale_neg * (sim_mat - self.thresh))
        # +/-1e16 sentinels keep masked-out entries from winning min/max.
        p_sim = torch.where(is_pos, sim_mat, torch.full_like(sim_mat, 1e16))
        n_sim = torch.where(is_neg, sim_mat, torch.full_like(sim_mat, -1e16))
        min_p_sim, _ = torch.min(p_sim, dim=1, keepdim=True)
        max_n_sim, _ = torch.max(n_sim, dim=1, keepdim=True)
        zeros = torch.zeros_like(sim_mat)
        hard_p = torch.where(p_sim - self.margin < max_n_sim, pos_exp, zeros).sum(dim=-1)
        hard_n = torch.where(n_sim + self.margin > min_p_sim, neg_exp, zeros).sum(dim=-1)
        pos_loss = torch.log(1 + hard_p).sum() / self.scale_pos
        neg_loss = torch.log(1 + hard_n).sum() / self.scale_neg
        return pos_loss + neg_loss

    def ms_sample(self, sim_mat, label):
        """Loss for a cross-set matrix: equal labels are positives, all others negatives."""
        # Masks follow sim_mat's device instead of a hard-coded .cuda(), so
        # the loss also runs on CPU.
        same = self._same_label(label, sim_mat.device)
        return self._mined_loss(sim_mat, same, ~same)

    def ms_sample_cbcb_clcl(self, sim_mat, label):
        """Loss for a within-set matrix; diagonal (self) pairs are neither positive nor negative."""
        same = self._same_label(label, sim_mat.device)
        eye = torch.eye(same.shape[0], dtype=torch.bool, device=same.device)
        # Equivalent to the original "(mask + eye) == 1" / "== 0" tests.
        return self._mined_loss(sim_mat, same ^ eye, ~(same | eye))

    def ms_sample_cbcb_clcl_trans(self, sim_mat, label):
        """Loss for an (N, N-1) matrix, using the label mask with its diagonal removed."""
        same = self._same_label(label, sim_mat.device)
        n_sha = same.shape[0]
        off_diag = ~torch.eye(n_sha, dtype=torch.bool, device=same.device)
        is_pos = same[off_diag].reshape(n_sha, n_sha - 1)
        return self._mined_loss(sim_mat, is_pos, ~is_pos)

    def compute_sharded_cosine_similarity(self, tensor1, tensor2, shard_size):
        """Average token-to-token cosine similarity for every sample pair.

        Args:
            tensor1: Tensor of shape (B, T, D).
            tensor2: Tensor indexed the same way; assumed to also be
                (B, T, D) since the result is allocated as (B, B).
            shard_size: Number of tokens handled per shard.

        Returns:
            (B, B) tensor holding the mean over all T*T token pairs.
        """
        B, T, D = tensor1.shape
        average_sim_matrix = torch.zeros((B, B), device=tensor1.device)

        # Every shard of tensor1 is paired with every shard of tensor2.
        # Pairing only same-index shards (as before) covers just the
        # block-diagonal of the T x T token grid, so dividing by T*T then
        # under-counts whenever shard_size < T.
        for start1 in range(0, T, shard_size):
            shard1 = tensor1[:, start1:start1 + shard_size, :].unsqueeze(1).unsqueeze(3)
            for start2 in range(0, T, shard_size):
                shard2 = tensor2[:, start2:start2 + shard_size, :].unsqueeze(0).unsqueeze(2)
                shard_cos_sim = F.cosine_similarity(shard1, shard2, dim=-1)
                average_sim_matrix += torch.sum(shard_cos_sim, dim=[2, 3])

        # Normalize by the total number of token pairs (T*T)
        average_sim_matrix /= (T * T)

        return average_sim_matrix

    def forward(self, x_contactless, x_contactbased, x_cl_tokens, x_cb_tokens, labels, device):
        """Return the summed loss; ``x_cl_tokens``, ``x_cb_tokens`` and ``device`` are unused."""
        cl = self.l2_norm(x_contactless)
        cb = self.l2_norm(x_contactbased)
        sim_mat_clcl = F.linear(cl, cl)
        sim_mat_cbcb = F.linear(cb, cb)
        sim_mat_cbcl = F.linear(cb, cl)

        loss2 = self.ms_sample_cbcb_clcl(sim_mat_clcl, labels) + self.ms_sample_cbcb_clcl(sim_mat_clcl.t(), labels)
        loss3 = self.ms_sample_cbcb_clcl(sim_mat_cbcb, labels) + self.ms_sample_cbcb_clcl(sim_mat_cbcb.t(), labels)
        loss4 = self.ms_sample(sim_mat_cbcl, labels) + self.ms_sample(sim_mat_cbcl.t(), labels)
        return loss4 + loss2 + loss3

    def l2_norm(self, input):
        """Scale each row of ``input`` to unit L2 norm (1e-12 added under the root)."""
        input_size = input.size()
        normp = torch.sum(torch.pow(input, 2), 1).add_(1e-12)
        norm = torch.sqrt(normp)
        _output = torch.div(input, norm.view(-1, 1).expand_as(input))
        return _output.view(input_size)
134
+
135
class DualMSLoss_FineGrained_domain_agnostic(nn.Module):
    """Multi-similarity style contrastive loss plus a domain-classification term.

    ``forward`` L2-normalises the contactless and contact-based embeddings,
    builds three cosine-similarity matrices (contactless/contactless,
    contact-based/contact-based, contact-based/contactless), sums the mined
    loss of each matrix and of its transpose, and adds three times the
    cross-entropy of the domain predictions.

    NOTE: the ``margin`` constructor argument is accepted for interface
    compatibility but has never taken effect - the margin is fixed at 0.5.
    ``max_violation`` is stored and not read by any method of this class.
    """

    def __init__(self, margin=0, max_violation=False):
        super(DualMSLoss_FineGrained_domain_agnostic, self).__init__()
        self.max_violation = max_violation
        self.thresh = 0.5
        # Fixed value; the original code overwrote ``margin`` with it too.
        self.margin = 0.5  # 0.1
        self.scale_pos = 2
        self.scale_neg = 40.0
        self.criterion = nn.CrossEntropyLoss()

    @staticmethod
    def _same_label(label, device):
        """Boolean (N, N) matrix, True where two samples share a label."""
        return torch.eq(label.view(-1, 1) - label.view(1, -1), 0.0).to(device)

    def _mined_loss(self, sim_mat, is_pos, is_neg):
        """Mine hard pairs per row of ``sim_mat`` and return the scalar loss.

        A positive counts when its similarity minus the margin is below the
        row's largest negative; a negative counts when its similarity plus
        the margin exceeds the row's smallest positive.
        """
        pos_exp = torch.exp(-self.scale_pos * (sim_mat - self.thresh))
        neg_exp = torch.exp(self.scale_neg * (sim_mat - self.thresh))
        # +/-1e16 sentinels keep masked-out entries from winning min/max.
        p_sim = torch.where(is_pos, sim_mat, torch.full_like(sim_mat, 1e16))
        n_sim = torch.where(is_neg, sim_mat, torch.full_like(sim_mat, -1e16))
        min_p_sim, _ = torch.min(p_sim, dim=1, keepdim=True)
        max_n_sim, _ = torch.max(n_sim, dim=1, keepdim=True)
        zeros = torch.zeros_like(sim_mat)
        hard_p = torch.where(p_sim - self.margin < max_n_sim, pos_exp, zeros).sum(dim=-1)
        hard_n = torch.where(n_sim + self.margin > min_p_sim, neg_exp, zeros).sum(dim=-1)
        pos_loss = torch.log(1 + hard_p).sum() / self.scale_pos
        neg_loss = torch.log(1 + hard_n).sum() / self.scale_neg
        return pos_loss + neg_loss

    def ms_sample(self, sim_mat, label):
        """Loss for a cross-set matrix: equal labels are positives, all others negatives."""
        # Masks follow sim_mat's device instead of a hard-coded .cuda(), so
        # the loss also runs on CPU.
        same = self._same_label(label, sim_mat.device)
        return self._mined_loss(sim_mat, same, ~same)

    def ms_sample_cbcb_clcl(self, sim_mat, label):
        """Loss for a within-set matrix; diagonal (self) pairs are neither positive nor negative."""
        same = self._same_label(label, sim_mat.device)
        eye = torch.eye(same.shape[0], dtype=torch.bool, device=same.device)
        # Equivalent to the original "(mask + eye) == 1" / "== 0" tests.
        return self._mined_loss(sim_mat, same ^ eye, ~(same | eye))

    def ms_sample_cbcb_clcl_trans(self, sim_mat, label):
        """Loss for an (N, N-1) matrix, using the label mask with its diagonal removed."""
        same = self._same_label(label, sim_mat.device)
        n_sha = same.shape[0]
        off_diag = ~torch.eye(n_sha, dtype=torch.bool, device=same.device)
        is_pos = same[off_diag].reshape(n_sha, n_sha - 1)
        return self._mined_loss(sim_mat, is_pos, ~is_pos)

    def compute_sharded_cosine_similarity(self, tensor1, tensor2, shard_size):
        """Average token-to-token cosine similarity for every sample pair.

        Args:
            tensor1: Tensor of shape (B, T, D).
            tensor2: Tensor indexed the same way; assumed to also be
                (B, T, D) since the result is allocated as (B, B).
            shard_size: Number of tokens handled per shard.

        Returns:
            (B, B) tensor holding the mean over all T*T token pairs.
        """
        B, T, D = tensor1.shape
        average_sim_matrix = torch.zeros((B, B), device=tensor1.device)

        # Every shard of tensor1 is paired with every shard of tensor2.
        # Pairing only same-index shards (as before) covers just the
        # block-diagonal of the T x T token grid, so dividing by T*T then
        # under-counts whenever shard_size < T.
        for start1 in range(0, T, shard_size):
            shard1 = tensor1[:, start1:start1 + shard_size, :].unsqueeze(1).unsqueeze(3)
            for start2 in range(0, T, shard_size):
                shard2 = tensor2[:, start2:start2 + shard_size, :].unsqueeze(0).unsqueeze(2)
                shard_cos_sim = F.cosine_similarity(shard1, shard2, dim=-1)
                average_sim_matrix += torch.sum(shard_cos_sim, dim=[2, 3])

        # Normalize by the total number of token pairs (T*T)
        average_sim_matrix /= (T * T)

        return average_sim_matrix

    def forward(self, x_contactless, x_contactbased, x_cl_tokens, x_cb_tokens, labels, device, domain_class_cl, domain_class_cb, domain_class_cl_gt, domain_class_cb_gt):
        """Return the metric loss plus 3x the domain cross-entropy.

        ``x_cl_tokens``, ``x_cb_tokens`` and ``device`` are unused.
        """
        cl = self.l2_norm(x_contactless)
        cb = self.l2_norm(x_contactbased)
        sim_mat_clcl = F.linear(cl, cl)
        sim_mat_cbcb = F.linear(cb, cb)
        sim_mat_cbcl = F.linear(cb, cl)

        loss2 = self.ms_sample_cbcb_clcl(sim_mat_clcl, labels) + self.ms_sample_cbcb_clcl(sim_mat_clcl.t(), labels)
        loss3 = self.ms_sample_cbcb_clcl(sim_mat_cbcb, labels) + self.ms_sample_cbcb_clcl(sim_mat_cbcb.t(), labels)
        loss4 = self.ms_sample(sim_mat_cbcl, labels) + self.ms_sample(sim_mat_cbcl.t(), labels)

        # Domain predictions and targets of both sets are scored together.
        pred = torch.cat([domain_class_cl, domain_class_cb])
        gt = torch.cat([domain_class_cl_gt, domain_class_cb_gt])
        domain_class_loss = self.criterion(pred, gt)
        return loss4 + loss2 + loss3 + (3 * domain_class_loss)

    def l2_norm(self, input):
        """Scale each row of ``input`` to unit L2 norm (1e-12 added under the root)."""
        input_size = input.size()
        normp = torch.sum(torch.pow(input, 2), 1).add_(1e-12)
        norm = torch.sqrt(normp)
        _output = torch.div(input, norm.view(-1, 1).expand_as(input))
        return _output.view(input_size)
258
+
259
class DualMSLoss_FineGrained_domain_agnostic_ft(nn.Module):
    """Multi-similarity style contrastive loss with the domain-agnostic call signature.

    ``forward`` L2-normalises the contactless and contact-based embeddings,
    builds three cosine-similarity matrices (contactless/contactless,
    contact-based/contact-based, contact-based/contactless) and sums the
    mined loss of each matrix and of its transpose.  The domain-class
    arguments are accepted but do not contribute to the returned loss.

    NOTE: the ``margin`` constructor argument is accepted for interface
    compatibility but has never taken effect - the margin is fixed at 0.7.
    ``max_violation`` is stored and not read by any method of this class;
    ``criterion`` is created but not used by ``forward``.
    """

    def __init__(self, margin=0, max_violation=False):
        super(DualMSLoss_FineGrained_domain_agnostic_ft, self).__init__()
        self.max_violation = max_violation
        self.thresh = 0.5
        # Fixed value; the original code overwrote ``margin`` with it too.
        self.margin = 0.7  # 0.1
        self.scale_pos = 2
        self.scale_neg = 40.0
        self.criterion = nn.CrossEntropyLoss()

    @staticmethod
    def _same_label(label, device):
        """Boolean (N, N) matrix, True where two samples share a label."""
        return torch.eq(label.view(-1, 1) - label.view(1, -1), 0.0).to(device)

    def _mined_loss(self, sim_mat, is_pos, is_neg):
        """Mine hard pairs per row of ``sim_mat`` and return the scalar loss.

        A positive counts when its similarity minus the margin is below the
        row's largest negative; a negative counts when its similarity plus
        the margin exceeds the row's smallest positive.
        """
        pos_exp = torch.exp(-self.scale_pos * (sim_mat - self.thresh))
        neg_exp = torch.exp(self.scale_neg * (sim_mat - self.thresh))
        # +/-1e16 sentinels keep masked-out entries from winning min/max.
        p_sim = torch.where(is_pos, sim_mat, torch.full_like(sim_mat, 1e16))
        n_sim = torch.where(is_neg, sim_mat, torch.full_like(sim_mat, -1e16))
        min_p_sim, _ = torch.min(p_sim, dim=1, keepdim=True)
        max_n_sim, _ = torch.max(n_sim, dim=1, keepdim=True)
        zeros = torch.zeros_like(sim_mat)
        hard_p = torch.where(p_sim - self.margin < max_n_sim, pos_exp, zeros).sum(dim=-1)
        hard_n = torch.where(n_sim + self.margin > min_p_sim, neg_exp, zeros).sum(dim=-1)
        pos_loss = torch.log(1 + hard_p).sum() / self.scale_pos
        neg_loss = torch.log(1 + hard_n).sum() / self.scale_neg
        return pos_loss + neg_loss

    def ms_sample(self, sim_mat, label):
        """Loss for a cross-set matrix: equal labels are positives, all others negatives."""
        # Masks follow sim_mat's device instead of a hard-coded .cuda(), so
        # the loss also runs on CPU.
        same = self._same_label(label, sim_mat.device)
        return self._mined_loss(sim_mat, same, ~same)

    def ms_sample_cbcb_clcl(self, sim_mat, label):
        """Loss for a within-set matrix; diagonal (self) pairs are neither positive nor negative."""
        same = self._same_label(label, sim_mat.device)
        eye = torch.eye(same.shape[0], dtype=torch.bool, device=same.device)
        # Equivalent to the original "(mask + eye) == 1" / "== 0" tests.
        return self._mined_loss(sim_mat, same ^ eye, ~(same | eye))

    def ms_sample_cbcb_clcl_trans(self, sim_mat, label):
        """Loss for an (N, N-1) matrix, using the label mask with its diagonal removed."""
        same = self._same_label(label, sim_mat.device)
        n_sha = same.shape[0]
        off_diag = ~torch.eye(n_sha, dtype=torch.bool, device=same.device)
        is_pos = same[off_diag].reshape(n_sha, n_sha - 1)
        return self._mined_loss(sim_mat, is_pos, ~is_pos)

    def compute_sharded_cosine_similarity(self, tensor1, tensor2, shard_size):
        """Average token-to-token cosine similarity for every sample pair.

        Args:
            tensor1: Tensor of shape (B, T, D).
            tensor2: Tensor indexed the same way; assumed to also be
                (B, T, D) since the result is allocated as (B, B).
            shard_size: Number of tokens handled per shard.

        Returns:
            (B, B) tensor holding the mean over all T*T token pairs.
        """
        B, T, D = tensor1.shape
        average_sim_matrix = torch.zeros((B, B), device=tensor1.device)

        # Every shard of tensor1 is paired with every shard of tensor2.
        # Pairing only same-index shards (as before) covers just the
        # block-diagonal of the T x T token grid, so dividing by T*T then
        # under-counts whenever shard_size < T.
        for start1 in range(0, T, shard_size):
            shard1 = tensor1[:, start1:start1 + shard_size, :].unsqueeze(1).unsqueeze(3)
            for start2 in range(0, T, shard_size):
                shard2 = tensor2[:, start2:start2 + shard_size, :].unsqueeze(0).unsqueeze(2)
                shard_cos_sim = F.cosine_similarity(shard1, shard2, dim=-1)
                average_sim_matrix += torch.sum(shard_cos_sim, dim=[2, 3])

        # Normalize by the total number of token pairs (T*T)
        average_sim_matrix /= (T * T)

        return average_sim_matrix

    def forward(self, x_contactless, x_contactbased, x_cl_tokens, x_cb_tokens, labels, device, domain_class_cl, domain_class_cb, domain_class_cl_gt, domain_class_cb_gt):
        """Return the summed metric loss.

        ``x_cl_tokens``, ``x_cb_tokens``, ``device`` and the four
        ``domain_class_*`` arguments are unused.
        """
        cl = self.l2_norm(x_contactless)
        cb = self.l2_norm(x_contactbased)
        sim_mat_clcl = F.linear(cl, cl)
        sim_mat_cbcb = F.linear(cb, cb)
        sim_mat_cbcl = F.linear(cb, cl)

        loss2 = self.ms_sample_cbcb_clcl(sim_mat_clcl, labels) + self.ms_sample_cbcb_clcl(sim_mat_clcl.t(), labels)
        loss3 = self.ms_sample_cbcb_clcl(sim_mat_cbcb, labels) + self.ms_sample_cbcb_clcl(sim_mat_cbcb.t(), labels)
        loss4 = self.ms_sample(sim_mat_cbcl, labels) + self.ms_sample(sim_mat_cbcl.t(), labels)

        return loss4 + loss2 + loss3

    def l2_norm(self, input):
        """Scale each row of ``input`` to unit L2 norm (1e-12 added under the root)."""
        input_size = input.size()
        normp = torch.sum(torch.pow(input, 2), 1).add_(1e-12)
        norm = torch.sqrt(normp)
        _output = torch.div(input, norm.view(-1, 1).expand_as(input))
        return _output.view(input_size)
model.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import argparse
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch.optim as optim
7
+ from torchvision import datasets, transforms
8
+ from torch.optim.lr_scheduler import StepLR
9
+ import torchvision.models as models
10
+ import timm
11
+ from pprint import pprint
12
+ import numpy as np
13
+ from tqdm import tqdm
14
+ from torch.utils.data.sampler import BatchSampler
15
+ from gradient_reversal.module import GradientReversal
16
+
17
class SwinModel(nn.Module):
    """Shared-backbone encoder for contactless and contact-based images.

    Despite the name, the backbone created here is timm's
    ``vit_large_patch16_224_in21k``.  ``swin_cb`` is an alias of ``swin_cl``
    (the same module object), so both modalities share encoder weights while
    the state dict still exposes both key prefixes.

    NOTE(review): the code takes ``.mean(dim=1)`` of the backbone output, so
    it presumably expects per-token features of shape (batch, tokens, 1024)
    -- confirm against the timm build in use.
    """

    def __init__(self):
        super().__init__()
        self.swin_cl = timm.create_model('vit_large_patch16_224_in21k', pretrained=True, num_classes=0)
        self.swin_cb = self.swin_cl

        self.linear_cl = nn.Sequential(nn.Linear(1024, 1024),
                                       nn.ReLU(),
                                       nn.Linear(1024, 1024))
        # Not used by any method of this class; kept so the parameter set
        # (and therefore checkpoint keys) stays unchanged.
        self.linear_cb = nn.Linear(1024, 1024)

    def _set_encoder_trainable(self, trainable):
        """Set ``requires_grad`` on every backbone parameter."""
        for encoder in (self.swin_cl, self.swin_cb):
            for param in encoder.parameters():
                param.requires_grad = trainable

    def freeze_encoder(self):
        """Disable gradients for the backbone(s)."""
        self._set_encoder_trainable(False)

    def unfreeze_encoder(self):
        """Re-enable gradients for the backbone(s)."""
        self._set_encoder_trainable(True)

    def get_embeddings(self, image, ftype):
        """Encode one batch of images.

        Args:
            image: input batch passed directly to the backbone.
            ftype: ``"contactless"`` selects ``swin_cl``; any other value
                selects ``swin_cb``.

        Returns:
            ``(feat, tokens)`` where ``tokens`` is the raw backbone output
            and ``feat`` is ``linear_cl`` applied to its mean over dim 1.
        """
        swin = self.swin_cl if ftype == "contactless" else self.swin_cb
        tokens = swin(image)
        # Both modalities go through linear_cl, exactly as before; the
        # projection of the individual tokens was computed and then thrown
        # away, so it has been dropped.
        feat = self.linear_cl(tokens.mean(dim=1))
        return feat, tokens

    def forward(self, x_cl, x_cb):
        """Encode a contactless and a contact-based batch.

        Returns:
            ``(x_cl, x_cb, x_cl_tokens, x_cb_tokens)``: the projected mean
            embeddings followed by the raw backbone outputs.
        """
        x_cl_tokens = self.swin_cl(x_cl)
        x_cb_tokens = self.swin_cb(x_cb)

        x_cl = self.linear_cl(x_cl_tokens.mean(dim=1))
        x_cb = self.linear_cl(x_cb_tokens.mean(dim=1))

        return x_cl, x_cb, x_cl_tokens, x_cb_tokens
64
+
65
class SwinModel_domain_agnostic(nn.Module):
    """Shared-backbone encoder with a gradient-reversed domain classifier.

    The backbone is timm's ``vit_large_patch16_224_in21k`` (the "Swin" in
    the name is historical).  ``swin_cb`` aliases ``swin_cl``, so both
    modalities share encoder weights.  ``classify`` starts with a
    ``GradientReversal`` layer and maps the mean backbone feature to 8
    logits.

    NOTE(review): the code takes ``.mean(dim=1)`` of the backbone output, so
    it presumably expects per-token features of shape (batch, tokens, 1024)
    -- confirm against the timm build in use.
    """

    def __init__(self):
        super().__init__()
        self.swin_cl = timm.create_model('vit_large_patch16_224_in21k', pretrained=True, num_classes=0)
        self.swin_cb = self.swin_cl

        self.linear_cl = nn.Sequential(nn.Linear(1024, 1024),
                                       nn.ReLU(),
                                       nn.Linear(1024, 1024))
        # Not used by any method of this class; kept so the parameter set
        # (and therefore checkpoint keys) stays unchanged.
        self.linear_cb = nn.Linear(1024, 1024)
        self.classify = nn.Sequential(GradientReversal(alpha=0.6),  # original 0.8
                                      nn.Linear(1024, 512),
                                      nn.ReLU(),
                                      nn.Linear(512, 8))

    def _set_encoder_trainable(self, trainable):
        """Set ``requires_grad`` on every backbone parameter."""
        for encoder in (self.swin_cl, self.swin_cb):
            for param in encoder.parameters():
                param.requires_grad = trainable

    def freeze_encoder(self):
        """Disable gradients for the backbone(s)."""
        self._set_encoder_trainable(False)

    def unfreeze_encoder(self):
        """Re-enable gradients for the backbone(s)."""
        self._set_encoder_trainable(True)

    def get_embeddings(self, image, ftype):
        """Encode one batch of images.

        Args:
            image: input batch passed directly to the backbone.
            ftype: ``"contactless"`` selects ``swin_cl``; any other value
                selects ``swin_cb``.

        Returns:
            ``(feat, tokens)`` where ``tokens`` is the raw backbone output
            and ``feat`` is ``linear_cl`` applied to its mean over dim 1.
        """
        swin = self.swin_cl if ftype == "contactless" else self.swin_cb
        tokens = swin(image)
        # Both modalities go through linear_cl, exactly as before; the
        # projection of the individual tokens was computed and then thrown
        # away, so it has been dropped.
        feat = self.linear_cl(tokens.mean(dim=1))
        return feat, tokens

    def forward(self, x_cl, x_cb):
        """Encode both batches and predict their domain logits.

        Returns:
            ``(x_cl, x_cb, x_cl_tokens, x_cb_tokens, domain_class_cl,
            domain_class_cb)``: projected mean embeddings, raw backbone
            outputs, and the domain-classifier logits computed from the
            *unprojected* mean features.
        """
        x_cl_tokens = self.swin_cl(x_cl)
        x_cb_tokens = self.swin_cb(x_cb)

        x_cl_mean = x_cl_tokens.mean(dim=1)
        x_cb_mean = x_cb_tokens.mean(dim=1)

        x_cl = self.linear_cl(x_cl_mean)
        x_cb = self.linear_cl(x_cb_mean)

        domain_class_cl = self.classify(x_cl_mean)
        domain_class_cb = self.classify(x_cb_mean)

        return x_cl, x_cb, x_cl_tokens, x_cb_tokens, domain_class_cl, domain_class_cb
119
+
120
class SwinModel_Fusion(nn.Module):
    """Backbone plus a transformer that fuses the tokens of two fingerprints.

    The backbone is timm's ``vit_large_patch16_224_in21k`` (the "Swin" in
    the name is historical).  ``combine_features`` concatenates the token
    sequences of two images around a learned separator token, runs the
    2-layer ``fusion`` encoder over the result and scores the pair.

    ``encoder_layer``, ``output_logit_mlp`` and ``linear_cl`` are registered
    but not used by any method below; they are kept so the parameter set
    (and therefore checkpoint keys) stays unchanged.
    """

    def __init__(self):
        super().__init__()
        self.feature_dim = 1024
        self.swin_cl = timm.create_model('vit_large_patch16_224_in21k', pretrained=True, num_classes=0)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=self.feature_dim, nhead=4, dropout=0.5, batch_first=True, norm_first=True, activation="gelu")
        self.fusion = nn.TransformerEncoder(self.encoder_layer, num_layers=2)
        self.sep_token = nn.Parameter(torch.randn(1, 1, self.feature_dim))
        self.output_logit_mlp = nn.Sequential(nn.Linear(1024, 512),
                                              nn.ReLU(),
                                              nn.Dropout(),
                                              nn.Linear(512, 1))
        self.linear_cl = nn.Sequential(nn.Linear(1024, 1024),
                                       nn.ReLU(),
                                       nn.Linear(1024, 1024))

    @staticmethod
    def _extract_submodule_state(state_dict, marker):
        """Return the entries whose key contains ``marker``, with ``marker.`` removed."""
        prefix = marker + "."
        return {key.replace(prefix, ""): value
                for key, value in state_dict.items()
                if marker in key}

    def load_pretrained_models(self, swin_cl_path, fusion_ckpt_path):
        """Load backbone weights and fusion-head weights from two checkpoints.

        Args:
            swin_cl_path: checkpoint whose ``swin_cl.*`` entries initialise
                the backbone.
            fusion_ckpt_path: checkpoint providing ``encoder_layer.*``,
                ``fusion.*``, ``sep_token`` and ``output_logit_mlp.*``.
        """
        # Load onto the CPU first so a checkpoint saved on a GPU can be read
        # on any host; load_state_dict copies into the existing parameters.
        swin_cl_state_dict = torch.load(swin_cl_path, map_location="cpu")
        self.swin_cl.load_state_dict(
            self._extract_submodule_state(swin_cl_state_dict, "swin_cl"))

        fusion_params = torch.load(fusion_ckpt_path, map_location="cpu")
        self.encoder_layer.load_state_dict(
            self._extract_submodule_state(fusion_params, "encoder_layer"))
        self.fusion.load_state_dict(
            self._extract_submodule_state(fusion_params, "fusion"))

        # The parameter object is replaced, so put it on the device the
        # current separator token lives on.
        self.sep_token = nn.Parameter(
            fusion_params["sep_token"].to(self.sep_token.device))

        self.output_logit_mlp.load_state_dict(
            self._extract_submodule_state(fusion_params, "output_logit_mlp"))

    def l2_norm(self, input):
        """Scale every row of a 2-D tensor to unit Euclidean length.

        1e-12 is added to the squared norm before the square root so that an
        all-zero row does not divide by zero.
        """
        buffer = torch.pow(input, 2)
        normp = torch.sum(buffer, 1).add_(1e-12)
        norm = torch.sqrt(normp)
        _output = torch.div(input, norm.view(-1, 1).expand_as(input))
        return _output

    def combine_features(self, fingerprint_1_tokens, fingerprint_2_tokens):
        """Fuse two token sequences and score every pair in the batch.

        Args:
            fingerprint_1_tokens: tensor of shape (B, T1, D).
            fingerprint_2_tokens: tensor of shape (B, T2, D).

        Returns:
            ``(similarities, distances)``, each of shape (B,): the cosine
            similarity and the Euclidean distance between the mean fused
            representation of the first and of the second sequence.
        """
        batch_size = fingerprint_1_tokens.shape[0]
        # Split point is derived from the input instead of the previously
        # hard-coded 197, so any token count works (197 behaves as before).
        num_tokens_1 = fingerprint_1_tokens.shape[1]

        sep_token = self.sep_token.repeat(batch_size, 1, 1)
        combined = torch.cat((fingerprint_1_tokens, sep_token, fingerprint_2_tokens), dim=1)
        fused_match_representation = self.fusion(combined)

        # Layout along dim 1: [fingerprint 1 | separator | fingerprint 2].
        fingerprint_1 = fused_match_representation[:, :num_tokens_1, :].mean(dim=1)
        fingerprint_2 = fused_match_representation[:, num_tokens_1 + 1:, :].mean(dim=1)

        fingerprint_1_norm = self.l2_norm(fingerprint_1)
        fingerprint_2_norm = self.l2_norm(fingerprint_2)
        similarities = torch.sum(fingerprint_1_norm * fingerprint_2_norm, dim=1)

        # The distance uses the un-normalised fused means.
        differences = fingerprint_1 - fingerprint_2
        distances = torch.sqrt(torch.sum(differences ** 2, dim=1))
        return similarities, distances

    def get_tokens(self, image, ftype):
        """Return the backbone output for ``image``; ``ftype`` is ignored."""
        return self.swin_cl(image)

    def freeze_backbone(self):
        """Disable gradients for every backbone parameter."""
        for param in self.swin_cl.parameters():
            param.requires_grad = False

    def forward(self, x_cl, x_cb):
        """Return the backbone outputs for both batches as ``(cl, cb)``."""
        x_cl_tokens = self.swin_cl(x_cl)
        x_cb_tokens = self.swin_cl(x_cb)
        return x_cl_tokens, x_cb_tokens
rb_evaluation_phase1.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from datasets.rb_loader import RB_loader
3
+ from utils import Prev_RetMetric, l2_norm, compute_recall_at_k
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from model import SwinModel_domain_agnostic as Model
7
+ from sklearn.metrics import roc_curve, auc
8
+ import json
9
+ import torch.nn.functional as F
10
+
11
def _fpr_bracket(fpr, far):
    """Return the ROC indices just below and at-or-above a target FAR.

    Args:
        fpr: array of false-positive rates as returned by ``roc_curve``.
        far: target false-accept rate.

    Returns:
        ``(lower, upper)``: the largest index with ``fpr < far`` and the
        smallest index with ``fpr >= far``.

    Raises:
        ValueError: if either side of the bracket is empty.
    """
    lower = int(np.flatnonzero(fpr < far).max())
    upper = int(np.flatnonzero(fpr >= far).min())
    return lower, upper


if __name__ == '__main__':
    # Phase-1 evaluation on the RB test split: extract embeddings, then
    # report ROC-AUC, EER, TAR@FAR and recall@k for contactless-vs-contact
    # matching and for contactless-vs-contactless matching.
    device = torch.device('cuda')
    data = RB_loader(split='test')
    dataloader = torch.utils.data.DataLoader(data, batch_size=16, num_workers=1, pin_memory=True)
    model = Model().to(device)
    checkpoint = torch.load("ridgeformer_checkpoints/phase1_scratch.pt", map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint, strict=False)

    model.eval()
    cl_feats, cb_feats, cl_labels, cb_labels = [], [], [], []
    cl_feats_unnormed, cb_feats_unnormed = [], []
    print("Computing Test Recall")
    with torch.no_grad():
        for (x_cl, x_cb, target, cl_fname, cb_fname) in tqdm(dataloader):
            x_cl, x_cb = x_cl.to(device), x_cb.to(device)
            x_cl, _ = model.get_embeddings(x_cl, ftype="contactless")
            x_cb, _ = model.get_embeddings(x_cb, ftype="contactbased")
            # Raw embeddings feed the Euclidean term, normalised ones the
            # cosine term of the final score.
            cl_feats_unnormed.append(x_cl.cpu().detach().numpy())
            cb_feats_unnormed.append(x_cb.cpu().detach().numpy())
            cl_feats.append(l2_norm(x_cl).cpu().detach().numpy())
            cb_feats.append(l2_norm(x_cb).cpu().detach().numpy())
            target = target.cpu().detach().numpy()
            cl_labels.append(target)
            cb_labels.append(target)

    cl_feats = torch.from_numpy(np.concatenate(cl_feats))
    cb_feats = torch.from_numpy(np.concatenate(cb_feats))
    cl_labels = torch.from_numpy(np.concatenate(cl_labels))
    cb_labels = torch.from_numpy(np.concatenate(cb_labels))
    cl_feats_unnormed = torch.from_numpy(np.concatenate(cl_feats_unnormed))
    cb_feats_unnormed = torch.from_numpy(np.concatenate(cb_feats_unnormed))

    # Collapse the contact-based gallery to one (mean) template per identity.
    # The averaged normalised features are not re-normalised afterwards.
    unique_labels, inverse = torch.unique(cb_labels, return_inverse=True)
    identity_slots = range(len(unique_labels))
    cb_feats = torch.stack([cb_feats[inverse == i].mean(dim=0) for i in identity_slots])
    cb_feats_unnormed = torch.stack([cb_feats_unnormed[inverse == i].mean(dim=0) for i in identity_slots])
    cb_labels = unique_labels

    # CL2CB <---------------------------------------->
    cl_feats = cl_feats.numpy()
    cb_feats = cb_feats.numpy()
    cb_feats_unnormed = cb_feats_unnormed.numpy()
    cl_feats_unnormed = cl_feats_unnormed.numpy()

    # Score = cosine similarity - 0.1 * Euclidean distance.
    squared_diff = np.sum(np.square(cl_feats_unnormed[:, np.newaxis] - cb_feats_unnormed), axis=2)
    distance = -1 * np.sqrt(squared_diff)
    similarities = np.dot(cl_feats, np.transpose(cb_feats))
    scores_mat = similarities + 0.1 * distance
    scores = scores_mat.flatten().tolist()

    # 1 where probe and gallery identity agree, 0 otherwise (row-major,
    # matching the flattening of scores_mat).
    ids_mod = torch.eq(cl_labels.view(-1, 1), cb_labels.view(1, -1)).flatten().int().tolist()
    fpr, tpr, thresh = roc_curve(ids_mod, scores, drop_intermediate=True)

    # TAR@FAR=1e-2 deliberately takes the upper bracket only, while the
    # tighter operating points average both sides (as in the original).
    lower_fpr_idx, upper_fpr_idx = _fpr_bracket(fpr, 0.01)
    tar_far_102 = tpr[upper_fpr_idx]
    print(tpr[lower_fpr_idx], lower_fpr_idx, fpr[lower_fpr_idx], thresh[lower_fpr_idx])
    print(tpr[upper_fpr_idx], upper_fpr_idx, fpr[upper_fpr_idx], thresh[upper_fpr_idx])

    lower_fpr_idx, upper_fpr_idx = _fpr_bracket(fpr, 0.001)
    tar_far_103 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2
    print(tpr[lower_fpr_idx], lower_fpr_idx, fpr[lower_fpr_idx])
    print(tpr[upper_fpr_idx], upper_fpr_idx, fpr[upper_fpr_idx])

    lower_fpr_idx, upper_fpr_idx = _fpr_bracket(fpr, 0.0001)
    tar_far_104 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2
    print(tpr[lower_fpr_idx], lower_fpr_idx, fpr[lower_fpr_idx])
    print(tpr[upper_fpr_idx], upper_fpr_idx, fpr[upper_fpr_idx])

    fnr = 1 - tpr
    EER = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    roc_auc = auc(fpr, tpr)
    print(f"ROCAUC for CB2CL: {roc_auc * 100} %")
    print(f"EER for CB2CL: {EER * 100} %")
    print(f"TAR@FAR=10^-2 for CB2CL: {tar_far_102 * 100} %")
    print(f"TAR@FAR=10^-3 for CB2CL: {tar_far_103 * 100} %")
    print(f"TAR@FAR=10^-4 for CB2CL: {tar_far_104 * 100} %")
    scores_tensor = torch.from_numpy(scores_mat)
    for k in (1, 10, 50, 100):
        print(f"R@{k} for CB2CL: {compute_recall_at_k(scores_tensor, cl_labels, cb_labels, k) * 100} %")

    ################################################################################

    # CL2CL
    # Only the strict upper triangle is scored so each unordered pair is
    # counted once and self-matches are excluded.
    scores = torch.from_numpy(np.dot(cl_feats, np.transpose(cl_feats)))
    row, col = torch.triu_indices(row=scores.size(0), col=scores.size(1), offset=1)
    scores = scores[row, col].numpy().flatten().tolist()
    # Labels stay on the CPU; the same (row, col) indices keep them aligned
    # with the scores.
    same_identity = torch.eq(cl_labels.view(-1, 1), cl_labels.view(1, -1)).float()
    labels = same_identity[row, col].tolist()
    fpr, tpr, _ = roc_curve(labels, scores)

    lower_fpr_idx, upper_fpr_idx = _fpr_bracket(fpr, 0.01)
    tar_far_102 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2
    lower_fpr_idx, upper_fpr_idx = _fpr_bracket(fpr, 0.001)
    tar_far_103 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2
    lower_fpr_idx, upper_fpr_idx = _fpr_bracket(fpr, 0.0001)
    tar_far_104 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2

    fnr = 1 - tpr
    EER = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    roc_auc = auc(fpr, tpr)
    print(f"ROCAUC for CL2CL: {roc_auc * 100} %")
    print(f"EER for CL2CL: {EER * 100} %")
    print(f"TAR@FAR=10^-2 for CL2CL: {tar_far_102 * 100} %")
    print(f"TAR@FAR=10^-3 for CL2CL: {tar_far_103 * 100} %")
    print(f"TAR@FAR=10^-4 for CL2CL: {tar_far_104 * 100} %")
    cl_labels = cl_labels.cpu().detach().numpy()
    recall_score = Prev_RetMetric([cl_feats, cl_feats], [cl_labels, cl_labels], cl2cl=True)
    for k in (1, 10, 50, 100):
        print(f"R@{k} for CL2CL: {recall_score.recall_k(k=k) * 100} %")
rb_evaluation_phase2.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from datasets.rb_loader_cl import RB_loader_cl
3
+ from datasets.rb_loader_cb import RB_loader_cb
4
+ from utils import Prev_RetMetric, l2_norm, compute_recall_at_k
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+ from model import SwinModel_Fusion as Model
8
+ from sklearn.metrics import roc_curve, auc
9
+ import json
10
+ import torch.nn.functional as F
11
+
12
def get_fused_cross_score_matrix(model, cl_tokens, cb_tokens, shard_size=20):
    """Score every (contactless, contact-based) pair with the fusion model.

    The full cross product is processed in ``shard_size`` x ``shard_size``
    tiles so that only one tile of token pairs is materialised at a time.

    Args:
        model: object exposing ``combine_features(tokens_a, tokens_b)`` that
            returns ``(similarities, distances)`` with one entry per pair.
        cl_tokens: list of token tensors, concatenated along dim 0.
        cb_tokens: list of token tensors, concatenated along dim 0.
        shard_size: number of rows / columns handled per tile.

    Returns:
        CPU tensor of shape (num_cl, num_cb) holding
        ``similarity - 0.1 * distance`` for every pair.
    """
    cl_tokens = torch.cat(cl_tokens)
    cb_tokens = torch.cat(cb_tokens)

    num_cl = cl_tokens.shape[0]
    num_cb = cb_tokens.shape[0]
    similarity_matrix = torch.zeros((num_cl, num_cb))

    # Pure inference: without no_grad the fusion transformer would record an
    # autograd graph for every tile even though the result is detached.
    with torch.no_grad():
        for i_start in tqdm(range(0, num_cl, shard_size)):
            shard_i = cl_tokens[i_start:i_start + shard_size]
            rows = shard_i.shape[0]
            for j_start in range(0, num_cb, shard_size):
                shard_j = cb_tokens[j_start:j_start + shard_size]
                cols = shard_j.shape[0]

                # Broadcast to every (i, j) combination of the two shards.
                pairwise_i = shard_i.unsqueeze(1).expand(-1, cols, -1, -1)
                pairwise_j = shard_j.unsqueeze(0).expand(rows, -1, -1, -1)

                # Token count comes from the tensors (previously fixed at 197).
                similarity_scores, distances = model.combine_features(
                    pairwise_i.reshape(-1, shard_i.shape[-2], shard_i.shape[-1]),
                    pairwise_j.reshape(-1, shard_j.shape[-2], shard_j.shape[-1]),
                )
                scores = (similarity_scores - 0.1 * distances).reshape(rows, cols)
                similarity_matrix[i_start:i_start + rows, j_start:j_start + cols] = scores.cpu().detach()
    return similarity_matrix
40
+
41
def _fpr_bracket(fpr, far):
    """Return the ROC indices just below and at-or-above a target FAR.

    Args:
        fpr: array of false-positive rates as returned by ``roc_curve``.
        far: target false-accept rate.

    Returns:
        ``(lower, upper)``: the largest index with ``fpr < far`` and the
        smallest index with ``fpr >= far``.

    Raises:
        ValueError: if either side of the bracket is empty.
    """
    lower = int(np.flatnonzero(fpr < far).max())
    upper = int(np.flatnonzero(fpr >= far).min())
    return lower, upper


# Phase-2 evaluation on the RB test split: extract backbone tokens for both
# modalities, score all pairs with the fusion model, then report ROC-AUC,
# EER, TAR@FAR and recall@k.
device = torch.device('cuda')
data_cl = RB_loader_cl(split="test")
data_cb = RB_loader_cb(split="test")
dataloader_cb = torch.utils.data.DataLoader(data_cb, batch_size=16, num_workers=1, pin_memory=True)
dataloader_cl = torch.utils.data.DataLoader(data_cl, batch_size=16, num_workers=1, pin_memory=True)
model = Model().to(device)
checkpoint = torch.load("ridgeformer_checkpoints/phase2_scratch.pt", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint, strict=False)

model.eval()
cl_feats, cb_feats, cl_labels, cb_labels = [], [], [], []
print("Computing Test Recall")

# Token tensors stay on the model's device; labels are collected as numpy.
with torch.no_grad():
    for (x_cb, target) in tqdm(dataloader_cb):
        cb_feats.append(model.get_tokens(x_cb.to(device), 'contactbased'))
        cb_labels.append(target.cpu().detach().numpy())

with torch.no_grad():
    for (x_cl, target) in tqdm(dataloader_cl):
        cl_feats.append(model.get_tokens(x_cl.to(device), 'contactless'))
        cl_labels.append(target.cpu().detach().numpy())

cl_label = torch.from_numpy(np.concatenate(cl_labels))
cb_label = torch.from_numpy(np.concatenate(cb_labels))

# CB2CL <---------------------------------------->
# Scoring is inference only, so keep autograd off while the fusion model runs.
with torch.no_grad():
    scores_mat = get_fused_cross_score_matrix(model, cl_feats, cb_feats)
scores = scores_mat.cpu().detach().numpy().flatten().tolist()
# True where probe and gallery identity agree (row-major, like `scores`).
labels = torch.eq(cl_label.view(-1, 1), cb_label.view(1, -1)).flatten().tolist()
fpr, tpr, thresh = roc_curve(labels, scores, drop_intermediate=True)

# TAR@FAR=1e-2 takes the upper bracket only; the tighter operating points
# average both sides (as in the original).
lower_fpr_idx, upper_fpr_idx = _fpr_bracket(fpr, 0.01)
tar_far_102 = tpr[upper_fpr_idx]

lower_fpr_idx, upper_fpr_idx = _fpr_bracket(fpr, 0.001)
tar_far_103 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2

lower_fpr_idx, upper_fpr_idx = _fpr_bracket(fpr, 0.0001)
tar_far_104 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2

fnr = 1 - tpr
EER = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
roc_auc = auc(fpr, tpr)
print(f"ROCAUC for CB2CL: {roc_auc * 100} %")
print(f"EER for CB2CL: {EER * 100} %")
print(f"TAR@FAR=10^-2 for CB2CL: {tar_far_102 * 100} %")
print(f"TAR@FAR=10^-3 for CB2CL: {tar_far_103 * 100} %")
print(f"TAR@FAR=10^-4 for CB2CL: {tar_far_104 * 100} %")

for k in (1, 10, 50, 100):
    print(f"R@{k} for CB2CL: {compute_recall_at_k(scores_mat, cl_label, cb_label, k) * 100} %")

# CL2CL -------------------------
with torch.no_grad():
    scores_mat = get_fused_cross_score_matrix(model, cl_feats, cl_feats)
# ROC statistics use the strict upper triangle only (each unordered pair
# once, no self-matches); recall@k below still receives the full matrix.
row, col = torch.triu_indices(row=scores_mat.size(0), col=scores_mat.size(1), offset=1)
scores = scores_mat[row, col].cpu().detach().numpy().flatten().tolist()
# Labels stay on the CPU; the same (row, col) indices keep them aligned
# with the scores.
same_identity = torch.eq(cl_label.view(-1, 1), cl_label.view(1, -1)).float()
labels = same_identity[row, col].flatten().tolist()
fpr, tpr, thresh = roc_curve(labels, scores, drop_intermediate=True)

lower_fpr_idx, upper_fpr_idx = _fpr_bracket(fpr, 0.01)
tar_far_102 = tpr[upper_fpr_idx]

lower_fpr_idx, upper_fpr_idx = _fpr_bracket(fpr, 0.001)
tar_far_103 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2

lower_fpr_idx, upper_fpr_idx = _fpr_bracket(fpr, 0.0001)
tar_far_104 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2

fnr = 1 - tpr
EER = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
roc_auc = auc(fpr, tpr)
print(f"ROCAUC for CL2CL: {roc_auc * 100} %")
print(f"EER for CL2CL: {EER * 100} %")
print(f"TAR@FAR=10^-2 for CL2CL: {tar_far_102 * 100} %")
print(f"TAR@FAR=10^-3 for CL2CL: {tar_far_103 * 100} %")
print(f"TAR@FAR=10^-4 for CL2CL: {tar_far_104 * 100} %")

for k in (1, 10, 50, 100):
    print(f"R@{k} for CL2CL: {compute_recall_at_k(scores_mat, cl_label, cl_label, k) * 100} %")
requirements.txt ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file may be used to create an environment using:
2
+ # $ conda create --name <env> --file <this file>
3
+ # platform: linux-64
4
+ _libgcc_mutex=0.1=main
5
+ _openmp_mutex=5.1=1_gnu
6
+ absl-py=2.1.0=pypi_0
7
+ addict=2.4.0=pypi_0
8
+ aliyun-python-sdk-core=2.15.0=pypi_0
9
+ aliyun-python-sdk-kms=2.16.2=pypi_0
10
+ attrs=23.2.0=pypi_0
11
+ blas=1.0=mkl
12
+ bzip2=1.0.8=h5eee18b_6
13
+ ca-certificates=2024.7.2=h06a4308_0
14
+ cachetools=5.4.0=pypi_0
15
+ certifi=2024.2.2=pypi_0
16
+ cffi=1.16.0=pypi_0
17
+ charset-normalizer=2.1.1=pypi_0
18
+ click=8.1.7=pypi_0
19
+ colorama=0.4.6=pypi_0
20
+ coloredlogs=15.0.1=pypi_0
21
+ contourpy=1.1.1=pypi_0
22
+ crcmod=1.7=pypi_0
23
+ cryptography=42.0.5=pypi_0
24
+ cuda-cudart=11.8.89=0
25
+ cuda-cudart_linux-64=12.4.127=hd681fbe_0
26
+ cuda-cupti=11.8.87=0
27
+ cuda-libraries=11.8.0=0
28
+ cuda-nvrtc=11.8.89=0
29
+ cuda-nvtx=11.8.86=0
30
+ cuda-opencl=12.4.127=h6a678d5_0
31
+ cuda-runtime=11.8.0=0
32
+ cuda-version=12.4=hbda6634_3
33
+ cycler=0.12.1=pypi_0
34
+ entrypoints=0.4=pypi_0
35
+ ffmpeg=4.3=hf484d3e_0
36
+ flatbuffers=24.3.25=pypi_0
37
+ fonttools=4.53.1=pypi_0
38
+ freetype=2.12.1=h4a9f257_0
39
+ fsspec=2024.3.1=pypi_0
40
+ gmp=6.2.1=h295c915_3
41
+ gnutls=3.6.15=he1e5248_0
42
+ google-auth=2.33.0=pypi_0
43
+ google-auth-oauthlib=1.0.0=pypi_0
44
+ grpcio=1.65.4=pypi_0
45
+ httpcore=1.0.5=pypi_0
46
+ httpx=0.27.0=pypi_0
47
+ huggingface-hub=0.22.1=pypi_0
48
+ humanfriendly=10.0=pypi_0
49
+ idna=3.6=pypi_0
50
+ imageio=2.34.2=pypi_0
51
+ importlib-metadata=7.1.0=pypi_0
52
+ importlib-resources=6.4.0=pypi_0
53
+ intel-openmp=2023.1.0=hdb19cb5_46306
54
+ jinja2=3.1.3=pypi_0
55
+ jmespath=0.10.0=pypi_0
56
+ joblib=1.4.2=pypi_0
57
+ jpeg=9e=h5eee18b_2
58
+ jsonschema=4.21.1=pypi_0
59
+ jsonschema-specifications=2023.12.1=pypi_0
60
+ kaleido=0.2.1=pypi_0
61
+ kiwisolver=1.4.5=pypi_0
62
+ lame=3.100=h7b6447c_0
63
+ lcms2=2.12=h3be6417_0
64
+ ld_impl_linux-64=2.38=h1181459_1
65
+ lerc=3.0=h295c915_0
66
+ libcublas=11.11.3.6=0
67
+ libcufft=10.9.0.58=0
68
+ libcufile=1.9.1.3=h99ab3db_1
69
+ libcurand=10.3.5.147=h99ab3db_1
70
+ libcusolver=11.4.1.48=0
71
+ libcusparse=11.7.5.86=0
72
+ libdeflate=1.17=h5eee18b_1
73
+ libffi=3.4.4=h6a678d5_1
74
+ libgcc-ng=11.2.0=h1234567_1
75
+ libgomp=11.2.0=h1234567_1
76
+ libiconv=1.16=h5eee18b_3
77
+ libidn2=2.3.4=h5eee18b_0
78
+ libjpeg-turbo=2.0.0=h9bf148f_0
79
+ libnpp=11.8.0.86=0
80
+ libnvfatbin=12.4.127=h7934f7d_2
81
+ libnvjitlink=12.4.99=0
82
+ libnvjpeg=11.9.0.86=0
83
+ libpng=1.6.39=h5eee18b_0
84
+ libstdcxx-ng=11.2.0=h1234567_1
85
+ libtasn1=4.19.0=h5eee18b_0
86
+ libtiff=4.5.1=h6a678d5_0
87
+ libunistring=0.9.10=h27cfd23_0
88
+ libwebp-base=1.3.2=h5eee18b_0
89
+ llvm-openmp=14.0.6=h9e868ea_0
90
+ llvmlite=0.41.1=pypi_0
91
+ lz4-c=1.9.4=h6a678d5_1
92
+ markdown=3.6=pypi_0
93
+ markdown-it-py=3.0.0=pypi_0
94
+ markupsafe=2.1.5=pypi_0
95
+ matplotlib=3.7.5=pypi_0
96
+ mdit-py-plugins=0.4.0=pypi_0
97
+ mkl=2023.1.0=h213fc3f_46344
98
+ mmcv=2.1.0=dev_0
99
+ mmdet=3.3.0=dev_0
100
+ mmengine=0.10.3=pypi_0
101
+ model-index=0.1.11=pypi_0
102
+ mpc=1.1.0=h10f8cd9_1
103
+ mpfr=4.0.2=hb69a4c5_1
104
+ mpmath=1.3.0=py38h06a4308_0
105
+ ncurses=6.4=h6a678d5_0
106
+ nettle=3.7.3=hbbd107a_1
107
+ networkx=3.1=py38h06a4308_0
108
+ numba=0.58.1=pypi_0
109
+ numpy=1.24.4=pypi_0
110
+ nvidia-cublas-cu11=11.11.3.6=pypi_0
111
+ nvidia-cuda-cupti-cu11=11.8.87=pypi_0
112
+ nvidia-cuda-nvrtc-cu11=11.8.89=pypi_0
113
+ nvidia-cuda-runtime-cu11=11.8.89=pypi_0
114
+ nvidia-cudnn-cu11=8.7.0.84=pypi_0
115
+ nvidia-cufft-cu11=10.9.0.58=pypi_0
116
+ nvidia-curand-cu11=10.3.0.86=pypi_0
117
+ nvidia-cusolver-cu11=11.4.1.48=pypi_0
118
+ nvidia-cusparse-cu11=11.7.5.86=pypi_0
119
+ nvidia-nccl-cu11=2.19.3=pypi_0
120
+ nvidia-nvtx-cu11=11.8.86=pypi_0
121
+ oauthlib=3.2.2=pypi_0
122
+ ocl-icd=2.3.2=h5eee18b_1
123
+ onnxruntime=1.18.1=pypi_0
124
+ opencv-python=4.10.0.84=pypi_0
125
+ opencv-python-headless=4.10.0.84=pypi_0
126
+ opendatalab=0.0.10=pypi_0
127
+ openh264=2.1.1=h4ff587b_0
128
+ openjpeg=2.4.0=h9ca470c_2
129
+ openmim=0.3.9=pypi_0
130
+ openssl=3.0.14=h5eee18b_0
131
+ openxlab=0.0.37=pypi_0
132
+ ordered-set=4.1.0=pypi_0
133
+ oss2=2.17.0=pypi_0
134
+ packaging=24.0=pypi_0
135
+ pandas=2.0.3=pypi_0
136
+ pillow=9.0.1=pypi_0
137
+ pip=23.3.1=pypi_0
138
+ pkgutil-resolve-name=1.3.10=pypi_0
139
+ platformdirs=4.2.0=pypi_0
140
+ plotly=5.23.0=pypi_0
141
+ pooch=1.8.2=pypi_0
142
+ protobuf=5.27.3=pypi_0
143
+ pyasn1=0.6.0=pypi_0
144
+ pyasn1-modules=0.4.0=pypi_0
145
+ pycocotools=2.0.7=pypi_0
146
+ pycparser=2.21=pypi_0
147
+ pygments=2.17.2=pypi_0
148
+ pymatting=1.1.12=pypi_0
149
+ pyparsing=3.1.2=pypi_0
150
+ python=3.8.19=h955ad1f_0
151
+ python-dateutil=2.9.0.post0=pypi_0
152
+ pytorch-cuda=11.8=h7e8668a_5
153
+ pytorch-metric-learning=2.5.0=pypi_0
154
+ pytorch-mutex=1.0=cuda
155
+ pytz=2023.4=pypi_0
156
+ pywavelets=1.4.1=pypi_0
157
+ pyyaml=6.0.1=py38h5eee18b_0
158
+ readline=8.2=h5eee18b_0
159
+ referencing=0.34.0=pypi_0
160
+ rembg=2.0.58=pypi_0
161
+ requests=2.28.2=pypi_0
162
+ requests-oauthlib=2.0.0=pypi_0
163
+ rich=13.4.2=pypi_0
164
+ rpds-py=0.18.0=pypi_0
165
+ rsa=4.9=pypi_0
166
+ safetensors=0.4.2=pypi_0
167
+ scikit-image=0.19.3=pypi_0
168
+ scikit-learn=1.3.2=pypi_0
169
+ scipy=1.10.1=pypi_0
170
+ setuptools=60.2.0=pypi_0
171
+ shapely=2.0.3=pypi_0
172
+ six=1.16.0=pypi_0
173
+ sqlite=3.45.3=h5eee18b_0
174
+ sympy=1.12=py38h06a4308_0
175
+ tabulate=0.9.0=pypi_0
176
+ tbb=2021.8.0=hdb19cb5_0
177
+ tenacity=9.0.0=pypi_0
178
+ tensorboard=2.14.0=pypi_0
179
+ tensorboard-data-server=0.7.2=pypi_0
180
+ termcolor=2.4.0=pypi_0
181
+ terminaltables=3.1.10=pypi_0
182
+ threadpoolctl=3.5.0=pypi_0
183
+ tifffile=2023.7.10=pypi_0
184
+ timm=0.5.0=dev_0
185
+ tk=8.6.14=h39e8969_0
186
+ tomli=2.0.1=pypi_0
187
+ torch=2.2.2+cu118=pypi_0
188
+ torchaudio=2.2.2+cu118=pypi_0
189
+ torchvision=0.17.2+cu118=pypi_0
190
+ tqdm=4.66.5=pypi_0
191
+ triton=2.2.0=pypi_0
192
+ typing-extensions=4.10.0=pypi_0
193
+ tzdata=2024.1=pypi_0
194
+ urllib3=1.26.18=pypi_0
195
+ werkzeug=3.0.3=pypi_0
196
+ wheel=0.41.2=pypi_0
197
+ xz=5.4.6=h5eee18b_1
198
+ yaml=0.2.5=h7b6447c_0
199
+ yapf=0.40.2=pypi_0
200
+ zipp=3.18.1=pypi_0
201
+ zlib=1.2.13=h5eee18b_1
202
+ zstd=1.5.5=hc292b87_2
train_combined.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import argparse
3
+ import os
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.optim as optim
8
+ from torchvision import datasets, transforms
9
+ from torch.optim.lr_scheduler import StepLR, MultiStepLR
10
+ from datasets.hkpoly_test import hktest
11
+ from datasets.original_combined_train import Combined_original
12
+ from datasets.rb_loader import RB_loader
13
+ from loss import DualMSLoss_FineGrained_domain_agnostic_ft, DualMSLoss_FineGrained, DualMSLoss_FineGrained_domain_agnostic
14
+ import timm
15
+ from utils import Prev_RetMetric, RetMetric, compute_recall_at_k, l2_norm, compute_sharded_cosine_similarity, count_parameters
16
+ from pprint import pprint
17
+ import numpy as np
18
+ from tqdm import tqdm
19
+ from combined_sampler import BalancedSampler
20
+ from torch.utils.data.sampler import BatchSampler
21
+ from torch.nn.parallel import DataParallel
22
+ from model import SwinModel_domain_agnostic as Model
23
+ import matplotlib.pyplot as plt
24
+ from sklearn.metrics import roc_curve, auc
25
+ import json
26
+ from torch.utils.tensorboard import SummaryWriter
27
+
28
def train(args, model, device, train_loader, test_loader, optimizers, epoch, loss_func, pl_arg, stepping, log_writer):
    """Run one training epoch and return the mean step loss.

    Args:
        args: parsed CLI namespace; ``log_interval`` and ``dry_run`` are read.
        model: network called as ``model(x_cl, x_cb)``; must return the six
            values consumed by ``loss_func``.
        device: device every batch tensor is moved to.
        train_loader: yields ``(x_cl, x_cb, target, category_cl, category_cb)``.
        test_loader, epoch, pl_arg, log_writer: unused here; kept so existing
            call sites keep working.
        optimizers: iterable of optimizers that are zeroed and stepped together.
        loss_func: callable producing a scalar loss tensor.
        stepping: step counter, returned unchanged.

    Returns:
        ``(mean_loss, stepping)`` where ``mean_loss`` is a Python float, or
        ``nan`` when no step was recorded (e.g. ``--dry-run``).
    """
    model.train()
    steploss = []
    for batch_idx, (x_cl, x_cb, target, category_cl, category_cb) in enumerate(pbar := tqdm(train_loader)):
        x_cl, x_cb, target = x_cl.to(device), x_cb.to(device), target.to(device)
        category_cl, category_cb = category_cl.to(device), category_cb.to(device)

        for optimizer in optimizers:
            optimizer.zero_grad()

        x_cl, x_cb, x_cl_tokens, x_cb_tokens, domain_class_cl, domain_class_cb = model(x_cl, x_cb)
        loss = loss_func(x_cl, x_cb, x_cl_tokens, x_cb_tokens, target, device,
                         domain_class_cl, domain_class_cb, category_cl, category_cb)
        loss.backward()

        for optimizer in optimizers:
            optimizer.step()

        if batch_idx % args.log_interval == 0:
            if args.dry_run:
                break

        pbar.set_description(f"Loss {loss}")
        # Store a plain float: keeping the tensor would retain autograd
        # history (and device memory) for every step of the epoch.
        steploss.append(loss.item())

    # A dry run breaks before anything is recorded; avoid dividing by zero.
    mean_loss = sum(steploss) / len(steploss) if steploss else float("nan")
    return mean_loss, stepping
46
+
47
def l2_norm(input):
    """Scale each row of ``input`` to unit Euclidean length.

    The norm is taken over dimension 1 and 1e-12 is added to the squared sum
    before the square root, so an all-zero row stays zero instead of
    producing NaN. The result has the same shape as ``input``.
    """
    original_shape = input.size()
    row_norm = input.pow(2).sum(1).add_(1e-12).sqrt()
    scaled = input / row_norm.view(-1, 1).expand_as(input)
    return scaled.view(original_shape)
55
+
56
def _tar_at_far(fpr, tpr, far):
    """Return the TAR at ``far``, averaged over the two ROC points bracketing it."""
    # roc_curve returns fpr in increasing order, so a binary search finds the
    # first point with fpr >= far; the point before it is the last with fpr < far.
    upper_idx = min(int(np.searchsorted(fpr, far, side="left")), len(fpr) - 1)
    lower_idx = max(upper_idx - 1, 0)
    return (tpr[lower_idx] + tpr[upper_idx]) / 2


def hkpoly_test_fn(model,device,test_loader,epoch,plot_argument):
    """Evaluate contact-based vs. contactless matching on ``test_loader``.

    Embeddings of both modalities are L2-normalised and compared with a dot
    product. The score matrix and a ROC plot are written under
    ``combined_models_scores/``; ROC-AUC, EER, TAR@FAR and recall@k are printed.

    Args:
        model: network exposing ``get_embeddings(x, modality)``.
        device: device the input batches are moved to.
        test_loader: yields ``(x_cl, x_cb, label)``; the one label is used for
            both images of the pair.
        epoch: value embedded in the output file names.
        plot_argument: four strings embedded in the output file names.

    Returns:
        ``(recall@1, EER, TAR@FAR=1e-2, TAR@FAR=1e-3, TAR@FAR=1e-4)``, all in
        percent.
    """
    model.eval()
    cl_feats, cb_feats, batch_labels = [], [], []
    with torch.no_grad():
        for x_cl, x_cb, label in tqdm(test_loader):
            x_cl, x_cb = x_cl.to(device), x_cb.to(device)
            x_cl_feat, _ = model.get_embeddings(x_cl, 'contactless')
            x_cb_feat, _ = model.get_embeddings(x_cb, 'contactbased')
            cl_feats.append(l2_norm(x_cl_feat).cpu().numpy())
            cb_feats.append(l2_norm(x_cb_feat).cpu().numpy())
            batch_labels.append(label.cpu().numpy())

    cl_feats = np.concatenate(cl_feats)
    cb_feats = np.concatenate(cb_feats)
    cl_label = np.concatenate(batch_labels)
    cb_label = cl_label.copy()

    out_dir = "combined_models_scores"
    os.makedirs(out_dir, exist_ok=True)
    run_tag = "_".join(plot_argument[:4])

    # CB2CL: rows index contactless samples, columns contact-based samples.
    scores = np.dot(cl_feats, np.transpose(cb_feats))
    np.save(out_dir + "/task1_cb2cl_score_matrix_" + str(epoch) + "_" + run_tag + ".npy", scores)
    genuine = (cl_label.reshape(-1, 1) == cb_label.reshape(1, -1)).ravel()

    fpr, tpr, _ = roc_curve(genuine, scores.ravel())
    tar_far_102 = _tar_at_far(fpr, tpr, 0.01)
    tar_far_103 = _tar_at_far(fpr, tpr, 0.001)
    tar_far_104 = _tar_at_far(fpr, tpr, 0.0001)
    fnr = 1 - tpr
    # EER is read off as the FPR at the point where FNR and FPR are closest.
    EER = fpr[np.nanargmin(np.absolute(fnr - fpr))]
    roc_auc = auc(fpr, tpr)

    fig = plt.figure()
    plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], 'k--', label='No Skill')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve CB2CL task1')
    plt.legend(loc="lower right")
    plt.savefig(out_dir + "/roc_curve_cb2cl_task1__" + run_tag + str(epoch) + ".png", dpi=300, bbox_inches='tight')
    # This runs every epoch; unclosed figures would pile up in memory.
    plt.close(fig)

    print(f"ROCAUC for CB2CL: {roc_auc * 100} %")
    print(f"EER for CB2CL: {EER * 100} %")
    eer_cb2cl = EER * 100
    print(f"TAR@FAR=10^-2 for CB2CL: {tar_far_102 * 100} %")
    print(f"TAR@FAR=10^-3 for CB2CL: {tar_far_103 * 100} %")
    print(f"TAR@FAR=10^-4 for CB2CL: {tar_far_104 * 100} %")
    cbcltf102 = tar_far_102 * 100
    cbcltf103 = tar_far_103 * 100
    cbcltf104 = tar_far_104 * 100

    recall_score = Prev_RetMetric([cb_feats,cl_feats],[cb_label,cl_label],cl2cl = False)
    cl2cbk1 = recall_score.recall_k(k=1) * 100
    print(f"R@1 for CB2CL: {cl2cbk1} %")
    for k in (10, 50, 100):
        print(f"R@{k} for CB2CL: {recall_score.recall_k(k=k) * 100} %")

    return cl2cbk1,eer_cb2cl,cbcltf102,cbcltf103,cbcltf104
130
+
131
def main():
    """Train the domain-agnostic Swin model and evaluate it after every epoch.

    Parses the CLI options, builds the balanced training loader and the
    ``hktest`` evaluation loader, restores the phase-1 checkpoint, then for
    each epoch trains, evaluates with ``hkpoly_test_fn``, logs to TensorBoard
    and saves a checkpoint. The best CB2CL metrics are printed at the end.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description='Train SwinModel_domain_agnostic on combined manifests with per-epoch HKPolyU evaluation')
    # nargs='+' collects file paths; type=list would split one string into characters.
    parser.add_argument('--manifest-list', type=str, nargs='+', required=True,
                        help='list of manifest files from different datasets to train on')
    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--test-batch-size', type=int, default=16, metavar='N',
                        help='input batch size for testing (default: 16)')
    parser.add_argument('--epochs', type=int, default=50, metavar='N',
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--lr_linear', type=float, default=1.0, metavar='LR',
                        help='learning rate for the linear/classifier heads (default: 1.0)')
    parser.add_argument('--lr_swin', type=float, default=1.0, metavar='LR',
                        help='learning rate for the Swin encoder (default: 1.0)')
    # NOTE: --gamma is parsed but not used; the scheduler gamma is fixed below.
    parser.add_argument('--gamma', type=float, default=0.9, metavar='M',
                        help='Learning rate step gamma (default: 0.9)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--warmup', type=int, default=2, metavar='N',
                        help='number of epochs to train with the encoder frozen (default: 2)')
    parser.add_argument('--model-name', type=str, default="ridgeformer",
                        help='Name of the model for checkpointing')
    args = parser.parse_args()

    checkpoint_save_path = "ridgeformer_checkpoints/"
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    # makedirs also creates missing parents and tolerates existing directories.
    log_dir = "experiment_logs/" + args.model_name
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoint_save_path, exist_ok=True)

    log_writer = SummaryWriter(log_dir + "/", comment=str(args.batch_size) + str(args.lr_linear) + str(args.lr_swin))

    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")

    print("loading Normal RGB images -----------------------------")
    train_dataset = Combined_original(args.manifest_list, split="train")
    val_dataset = hktest(split="test")

    balanced_sampler = BalancedSampler(train_dataset, batch_size=args.batch_size, images_per_class=2)
    batch_sampler = BatchSampler(balanced_sampler, batch_size=args.batch_size, drop_last=True)

    train_kwargs = {'batch_sampler': batch_sampler}
    test_kwargs = {'batch_size': args.test_batch_size}

    if use_cuda:
        cuda_kwargs = {
            'num_workers': 1,
            'pin_memory': True,
        }
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)

    model = Model().to(device)
    ckpt = torch.load("ridgeformer_checkpoints/phase1_scratch.pt", map_location=torch.device('cpu'))
    # strict=False: keys missing from or unexpected in the checkpoint are ignored.
    model.load_state_dict(ckpt, strict=False)
    print("Number of Trainable Parameters: - ", count_parameters(model))

    loss_func = DualMSLoss_FineGrained_domain_agnostic()

    optimizer_swin = optim.AdamW(
        [
            {"params": model.swin_cl.parameters(), "lr": args.lr_swin},
            {"params": model.classify.parameters(), "lr": args.lr_linear},
            {"params": model.linear_cl.parameters(), "lr": args.lr_linear},
            {"params": model.linear_cb.parameters(), "lr": args.lr_linear},
        ],
        weight_decay=0.000001,
        lr=args.lr_swin)

    scheduler_swin = MultiStepLR(optimizer_swin, milestones=[100], gamma=0.7)

    # hkpoly_test_fn reports CB2CL metrics only, so only those are tracked.
    cb2cl_lst = []
    eer_cb2cl_lst = []
    cbcltf102_lst, cbcltf103_lst, cbcltf104_lst = [], [], []
    run_args = [args.model_name, str(args.batch_size), str(args.lr_linear), str(args.lr_swin)]
    stepping = 1
    for epoch in range(1, args.epochs + 1):
        print(f"running epoch------ {epoch}")
        if epoch > args.warmup:
            print("Training with Swin")
            model.unfreeze_encoder()
        else:
            print("Training only linear")
            model.freeze_encoder()

        avg_step_loss, stepping = train(args, model, device, train_loader, test_loader, [optimizer_swin],
                                        epoch, loss_func, run_args, stepping, log_writer)

        print(f"Learning Rate for {epoch} for swin = {scheduler_swin.get_last_lr()}")
        log_writer.add_scalar('Swin_LR/epoch', scheduler_swin.get_last_lr()[0], epoch)

        if epoch > args.warmup:
            scheduler_swin.step()

        cl2cbk1, eer_cb2cl, cbcltf102, cbcltf103, cbcltf104 = hkpoly_test_fn(
            model, device, test_loader, epoch, run_args)
        cb2cl_lst.append(cl2cbk1)
        eer_cb2cl_lst.append(eer_cb2cl)
        cbcltf102_lst.append(cbcltf102)
        cbcltf103_lst.append(cbcltf103)
        cbcltf104_lst.append(cbcltf104)

        log_writer.add_scalars('recall@1/epoch', {'CB2CL': cl2cbk1}, epoch)
        log_writer.add_scalars('EER/epoch', {'CB2CL': eer_cb2cl}, epoch)
        log_writer.add_scalars('TARFAR10^-2/epoch', {'CB2CL': cbcltf102}, epoch)
        log_writer.add_scalars('TARFAR10^-3/epoch', {'CB2CL': cbcltf103}, epoch)
        log_writer.add_scalars('TARFAR10^-4/epoch', {'CB2CL': cbcltf104}, epoch)
        log_writer.add_scalar('AvgLoss/epoch', avg_step_loss, epoch)

        torch.save(model.state_dict(), checkpoint_save_path + "combinedtrained_hkpolytest_" + args.model_name
                   + "_" + str(args.lr_linear) + "_" + str(args.lr_swin) + "_" + str(args.batch_size)
                   + str(epoch) + "_" + str(cl2cbk1) + ".pt")
    log_writer.close()

    if cb2cl_lst:
        print(f"Maximum recall@1 for CB2CL: {max(cb2cl_lst)} at epoch {cb2cl_lst.index(max(cb2cl_lst))+1}")
        print(f"Minimum EER for CB2CL: {min(eer_cb2cl_lst)} at epoch {eer_cb2cl_lst.index(min(eer_cb2cl_lst))+1}")
        print(f"Maximum TAR@FAR=10^-2 for CB2CL: {max(cbcltf102_lst)} at epoch {cbcltf102_lst.index(max(cbcltf102_lst))+1}")
        print(f"Maximum TAR@FAR=10^-3 for CB2CL: {max(cbcltf103_lst)} at epoch {cbcltf103_lst.index(max(cbcltf103_lst))+1}")
        print(f"Maximum TAR@FAR=10^-4 for CB2CL: {max(cbcltf104_lst)} at epoch {cbcltf104_lst.index(max(cbcltf104_lst))+1}")
271
+
272
+ if __name__ == '__main__':
273
+ main()
train_combined_fusion.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import argparse
3
+ import os
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.optim as optim
8
+ from torchvision import datasets, transforms
9
+ from torch.optim.lr_scheduler import StepLR, MultiStepLR
10
+ from datasets.hkpoly_test import hktest
11
+ from datasets.original_combined_train import Combined_original
12
+ from datasets.rb_loader import RB_loader
13
+ from loss import DualMSLoss_FineGrained_domain_agnostic_ft, DualMSLoss_FineGrained, DualMSLoss_FineGrained_domain_agnostic
14
+ import timm
15
+ from utils import Prev_RetMetric, RetMetric, compute_recall_at_k, l2_norm, compute_sharded_cosine_similarity, count_parameters
16
+ from pprint import pprint
17
+ import numpy as np
18
+ from tqdm import tqdm
19
+ from combined_sampler import BalancedSampler
20
+ from torch.utils.data.sampler import BatchSampler
21
+ from torch.nn.parallel import DataParallel
22
+ from model import SwinModel_Fusion as Model
23
+ import matplotlib.pyplot as plt
24
+ from sklearn.metrics import roc_curve, auc
25
+ import json
26
+ from torch.utils.tensorboard import SummaryWriter
27
+
28
def train(args, model, device, train_loader, test_loader, optimizers, epoch, loss_func, pl_arg, stepping, log_writer, checkpoint_save_path):
    """Run one epoch of fusion training with an evaluation every 50 batches.

    For each batch the model's token outputs for both modalities are paired
    exhaustively (N x N), scored with ``model.combine_features`` and trained
    with ``loss_func.ms_sample`` on the score matrix and its transpose.

    Args:
        args: parsed CLI namespace; ``log_interval`` and ``dry_run`` are read.
        model: network returning ``(x_cl_tokens, x_cb_tokens)`` and exposing
            ``combine_features``.
        device: device the batches, scores and losses are placed on.
        train_loader: yields 5-tuples; only the first three items are used.
        test_loader: loader passed to ``hkpoly_test_fn``.
        optimizers: iterable of optimizers that are zeroed and stepped together.
        epoch, pl_arg: forwarded to ``hkpoly_test_fn``.
        loss_func: object exposing ``ms_sample(sim_matrix, target)``.
        stepping: global step used as the x-axis of the ``*/step`` scalars.
        log_writer: TensorBoard writer.
        checkpoint_save_path: unused here; kept for call compatibility.

    Returns:
        ``(mean_loss, stepping)`` where ``mean_loss`` is a Python float, or
        ``nan`` when no step was recorded (e.g. ``--dry-run``).
    """
    model.train()
    steploss = []
    for batch_idx, (x_cl, x_cb, target, _, _) in enumerate(pbar := tqdm(train_loader)):
        x_cl, x_cb, target = x_cl.to(device), x_cb.to(device), target.to(device)
        for optimizer in optimizers:
            optimizer.zero_grad()
        x_cl_tokens, x_cb_tokens = model(x_cl, x_cb)

        N, M, D = x_cl_tokens.shape

        # Pair every contactless sample i with every contact-based sample j.
        # After the reshape, flat index i * N + j holds the pair (i, j).
        x = x_cl_tokens.unsqueeze(1).expand(N, N, M, D).reshape(N * N, M, D)
        y = x_cb_tokens.unsqueeze(0).expand(N, N, M, D).reshape(N * N, M, D)
        sim_matrix, _ = model.combine_features(x, y)
        sim_matrix = sim_matrix.view(N, N).to(device)

        # Use the requested device instead of a hard-coded .cuda().
        loss = (loss_func.ms_sample(sim_matrix, target).to(device)
                + loss_func.ms_sample(sim_matrix.t(), target.t()).to(device))
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()

        if batch_idx % args.log_interval == 0:
            if args.dry_run:
                break

        pbar.set_description(f"Loss {loss}")
        # Store a float so the autograd history of each step can be released.
        steploss.append(loss.item())

        if (batch_idx + 1) % 50 == 0:
            # hkpoly_test_fn returns five CB2CL metrics.
            cl2cbk1, eer_cb2cl, cbcltf102, cbcltf103, cbcltf104 = hkpoly_test_fn(
                model, device, test_loader, epoch, pl_arg)
            log_writer.add_scalars('recall@1/step', {'CB2CL': cl2cbk1}, stepping)
            log_writer.add_scalars('EER/step', {'CB2CL': eer_cb2cl}, stepping)
            log_writer.add_scalars('TARFAR10^-2/step', {'CB2CL': cbcltf102}, stepping)
            log_writer.add_scalars('TARFAR10^-4/step', {'CB2CL': cbcltf104}, stepping)
            stepping += 1
            # The evaluation switched the model to eval mode; switch back.
            model.train()

    mean_loss = sum(steploss) / len(steploss) if steploss else float("nan")
    return mean_loss, stepping
68
+
69
def l2_norm(input):
    """Return ``input`` with every row divided by its Euclidean norm.

    The squared sum along dimension 1 is offset by 1e-12 before the square
    root, which keeps the division finite for all-zero rows. The output has
    the same shape as ``input``.
    """
    shape = input.size()
    squared_sum = torch.sum(torch.pow(input, 2), 1).add_(1e-12)
    denom = torch.sqrt(squared_sum).view(-1, 1).expand_as(input)
    return torch.div(input, denom).view(shape)
77
+
78
def get_fused_cross_score_matrix(model, cl_tokens, cb_tokens, shard_size=20):
    """Score every (contactless, contact-based) token pair with the fusion head.

    The full pair grid is processed in ``shard_size`` x ``shard_size`` tiles so
    that only one tile of pairs is materialised at a time.

    Args:
        model: network exposing ``combine_features(a, b)`` that returns
            ``(similarity_scores, distances)`` with one value per pair.
        cl_tokens: list of token batches, concatenated along dim 0.
        cb_tokens: list of token batches, concatenated along dim 0.
        shard_size: tile edge length; the default matches the previous
            hard-coded value.

    Returns:
        CPU tensor of shape ``(num_cl, num_cb)`` holding
        ``similarity - 0.1 * distance`` for each pair.
    """
    cl_tokens = torch.cat(cl_tokens)
    cb_tokens = torch.cat(cb_tokens)
    num_cl, num_cb = cl_tokens.shape[0], cb_tokens.shape[0]
    # Take the per-sample token shape from the data rather than assuming 197 x 1024.
    cl_shape, cb_shape = cl_tokens.shape[1:], cb_tokens.shape[1:]
    similarity_matrix = torch.zeros((num_cl, num_cb))

    # Evaluation only: do not build autograd graphs for each tile.
    with torch.no_grad():
        for i_start in tqdm(range(0, num_cl, shard_size)):
            i_end = min(i_start + shard_size, num_cl)
            shard_i = cl_tokens[i_start:i_end]
            rows = i_end - i_start
            for j_start in range(0, num_cb, shard_size):
                j_end = min(j_start + shard_size, num_cb)
                shard_j = cb_tokens[j_start:j_end]
                cols = j_end - j_start

                # Expand with the actual tile size so the last, shorter tile
                # still pairs rows and columns correctly.
                pairwise_i = shard_i.unsqueeze(1).expand(-1, cols, -1, -1)
                pairwise_j = shard_j.unsqueeze(0).expand(rows, -1, -1, -1)

                similarity_scores, distances = model.combine_features(
                    pairwise_i.reshape(-1, *cl_shape),
                    pairwise_j.reshape(-1, *cb_shape))
                scores = (similarity_scores - 0.1 * distances).reshape(rows, cols)
                similarity_matrix[i_start:i_end, j_start:j_end] = scores.cpu().detach()
    return similarity_matrix
100
+
101
def _tar_at_far(fpr, tpr, far):
    """Return the TAR at ``far``, averaged over the two ROC points bracketing it."""
    # roc_curve returns fpr in increasing order, so a binary search finds the
    # first point with fpr >= far; the point before it is the last with fpr < far.
    upper_idx = min(int(np.searchsorted(fpr, far, side="left")), len(fpr) - 1)
    lower_idx = max(upper_idx - 1, 0)
    return (tpr[lower_idx] + tpr[upper_idx]) / 2


def hkpoly_test_fn(model,device,test_loader,epoch,plot_argument):
    """Evaluate fused contact-based vs. contactless matching on ``test_loader``.

    Tokens of both modalities are collected and scored pairwise with
    ``get_fused_cross_score_matrix``. The score matrix and a ROC plot are
    written under ``combined_models_scores/``; ROC-AUC, EER, TAR@FAR and
    recall@k are printed.

    Args:
        model: network exposing ``get_tokens(x, modality)``.
        device: device the input batches are moved to.
        test_loader: yields ``(x_cl, x_cb, label)``; the one label is used for
            both images of the pair.
        epoch: value embedded in the output file names.
        plot_argument: three strings embedded in the output file names.

    Returns:
        ``(recall@1, EER, TAR@FAR=1e-2, TAR@FAR=1e-3, TAR@FAR=1e-4)``, all in
        percent.
    """
    model.eval()
    cl_feats, cb_feats, batch_labels = [], [], []
    # The pairwise scoring is kept inside no_grad too, so no graphs are built.
    with torch.no_grad():
        for x_cl, x_cb, label in tqdm(test_loader):
            x_cl, x_cb = x_cl.to(device), x_cb.to(device)
            cl_feats.append(model.get_tokens(x_cl, 'contactless'))
            cb_feats.append(model.get_tokens(x_cb, 'contactbased'))
            batch_labels.append(label.cpu().numpy())

        # CB2CL: rows index contactless samples, columns contact-based samples.
        scores_mat = get_fused_cross_score_matrix(model, cl_feats, cb_feats)

    cl_label = torch.from_numpy(np.concatenate(batch_labels))
    cb_label = cl_label.clone()

    out_dir = "combined_models_scores"
    os.makedirs(out_dir, exist_ok=True)
    run_tag = "_".join(plot_argument[:3])

    scores_np = scores_mat.cpu().detach().numpy()
    np.save(out_dir + "/task1_cb2cl_score_matrix_" + str(epoch) + "_" + run_tag + ".npy", scores_np)
    # Same orientation as scores_mat: entry (i, j) compares cl sample i with cb sample j.
    genuine = torch.eq(cl_label.view(-1, 1) - cb_label.view(1, -1), 0).flatten().numpy()

    fpr, tpr, _ = roc_curve(genuine, scores_np.ravel())
    tar_far_102 = _tar_at_far(fpr, tpr, 0.01)
    tar_far_103 = _tar_at_far(fpr, tpr, 0.001)
    tar_far_104 = _tar_at_far(fpr, tpr, 0.0001)
    fnr = 1 - tpr
    # EER is read off as the FPR at the point where FNR and FPR are closest.
    EER = fpr[np.nanargmin(np.absolute(fnr - fpr))]
    roc_auc = auc(fpr, tpr)

    fig = plt.figure()
    plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], 'k--', label='No Skill')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve CB2CL task1')
    plt.legend(loc="lower right")
    plt.savefig(out_dir + "/roc_curve_cb2cl_task1__" + run_tag + str(epoch) + ".png", dpi=300, bbox_inches='tight')
    # This function runs repeatedly; unclosed figures would pile up in memory.
    plt.close(fig)

    print(f"ROCAUC for CB2CL: {roc_auc * 100} %")
    print(f"EER for CB2CL: {EER * 100} %")
    eer_cb2cl = EER * 100
    print(f"TAR@FAR=10^-2 for CB2CL: {tar_far_102 * 100} %")
    print(f"TAR@FAR=10^-3 for CB2CL: {tar_far_103 * 100} %")
    print(f"TAR@FAR=10^-4 for CB2CL: {tar_far_104 * 100} %")
    cbcltf102 = tar_far_102 * 100
    cbcltf103 = tar_far_103 * 100
    cbcltf104 = tar_far_104 * 100

    cl2cbk1 = compute_recall_at_k(scores_mat, cl_label, cb_label, 1) * 100
    print(f"R@1 for CB2CL: {cl2cbk1} %")
    for k in (10, 50, 100):
        print(f"R@{k} for CB2CL: {compute_recall_at_k(scores_mat, cl_label, cb_label, k) * 100} %")
    torch.cuda.empty_cache()

    return cl2cbk1,eer_cb2cl,cbcltf102,cbcltf103,cbcltf104
169
+
170
def main():
    """Train the fusion head on top of frozen backbones and evaluate each epoch.

    Parses the CLI options, loads the two pretrained checkpoints into the
    fusion model, freezes the backbone, builds the balanced training loader and
    the ``hktest`` evaluation loader, then for each epoch trains, evaluates
    with ``hkpoly_test_fn``, logs to TensorBoard and saves a checkpoint. The
    best CB2CL metrics are printed at the end.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description='Train the SwinModel_Fusion head on combined manifests with HKPolyU evaluation')
    # nargs='+' collects file paths; type=list would split one string into characters.
    parser.add_argument('--manifest-list', type=str, nargs='+', required=True,
                        help='list of manifest files')
    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--test-batch-size', type=int, default=16, metavar='N',
                        help='input batch size for testing (default: 16)')
    parser.add_argument('--epochs', type=int, default=50, metavar='N',
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--lr_fusion', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    # NOTE: --gamma and --warmup are parsed but not used by this script.
    parser.add_argument('--gamma', type=float, default=0.9, metavar='M',
                        help='Learning rate step gamma (default: 0.9)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--warmup', type=int, default=2, metavar='N',
                        help='warm up rate for feature extractor')
    parser.add_argument('--model-name', type=str, default="swinmodel",
                        help='Name of the model for checkpointing')
    args = parser.parse_args()

    # Resolve the device from the flags so --no-cuda and CPU-only hosts work.
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    model = Model().to(device)
    ckpt_combined_phase1_ft = "ridgeformer_checkpoints/combined_models_check/phase1_ft_hkpoly.pt"
    ckpt_combined_phase2 = "ridgeformer_checkpoints/phase2_scratch.pt"

    model.load_pretrained_models(ckpt_combined_phase1_ft, ckpt_combined_phase2)
    model.freeze_backbone()
    checkpoint_save_path = "ridgeformer_checkpoints/"

    # makedirs also creates missing parents and tolerates existing directories.
    log_dir = "experiment_logs/" + args.model_name
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoint_save_path, exist_ok=True)

    log_writer = SummaryWriter(log_dir + "/", comment=str(args.batch_size) + str(args.lr_fusion))

    torch.manual_seed(args.seed)

    print("loading Normal RGB images -----------------------------")
    train_dataset = Combined_original(args.manifest_list, split="train")
    val_dataset = hktest(split="test")

    balanced_sampler = BalancedSampler(train_dataset, batch_size=args.batch_size, images_per_class=2)
    batch_sampler = BatchSampler(balanced_sampler, batch_size=args.batch_size, drop_last=True)

    train_kwargs = {'batch_sampler': batch_sampler}
    test_kwargs = {'batch_size': args.test_batch_size}

    if use_cuda:
        cuda_kwargs = {
            'num_workers': 1,
            'pin_memory': True,
        }
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)

    print("Number of Trainable Parameters: - ", count_parameters(model))

    loss_func = DualMSLoss_FineGrained()
    optimizer_fusion = optim.AdamW(
        [
            {"params": model.output_logit_mlp.parameters(), "lr": args.lr_fusion},
            {"params": model.fusion.parameters(), "lr": args.lr_fusion},
            {"params": model.sep_token, "lr": args.lr_fusion},
            {"params": model.encoder_layer.parameters(), "lr": args.lr_fusion},
        ],
        weight_decay=0.000001,
        lr=args.lr_fusion)

    scheduler = MultiStepLR(optimizer_fusion, milestones=[3, 6, 9, 14], gamma=0.5)

    # hkpoly_test_fn reports CB2CL metrics only, so only those are tracked.
    cb2cl_lst = []
    eer_cb2cl_lst = []
    cbcltf102_lst, cbcltf103_lst, cbcltf104_lst = [], [], []
    run_args = [args.model_name, str(args.batch_size), str(args.lr_fusion)]
    stepping = 1
    for epoch in range(1, args.epochs + 1):
        print(f"running epoch------ {epoch}")
        avg_step_loss, stepping = train(args, model, device, train_loader, test_loader, [optimizer_fusion],
                                        epoch, loss_func, run_args, stepping, log_writer, checkpoint_save_path)

        # Both lines and both tags report the single fusion scheduler.
        print(f"Learning Rate for {epoch} for linear = {scheduler.get_last_lr()}")
        print(f"Learning Rate for {epoch} for swin = {scheduler.get_last_lr()}")

        log_writer.add_scalar('Liner_LR/epoch', scheduler.get_last_lr()[0], epoch)
        log_writer.add_scalar('Swin_LR/epoch', scheduler.get_last_lr()[0], epoch)

        scheduler.step()

        cl2cbk1, eer_cb2cl, cbcltf102, cbcltf103, cbcltf104 = hkpoly_test_fn(
            model, device, test_loader, epoch, run_args)
        cb2cl_lst.append(cl2cbk1)
        eer_cb2cl_lst.append(eer_cb2cl)
        cbcltf102_lst.append(cbcltf102)
        cbcltf103_lst.append(cbcltf103)
        cbcltf104_lst.append(cbcltf104)

        log_writer.add_scalars('recall@1/epoch', {'CB2CL': cl2cbk1}, epoch)
        log_writer.add_scalars('EER/epoch', {'CB2CL': eer_cb2cl}, epoch)
        log_writer.add_scalars('TARFAR10^-2/epoch', {'CB2CL': cbcltf102}, epoch)
        log_writer.add_scalars('TARFAR10^-4/epoch', {'CB2CL': cbcltf104}, epoch)
        log_writer.add_scalar('AvgLoss/epoch', avg_step_loss, epoch)

        torch.save(model.state_dict(), checkpoint_save_path + "combinedtrained_hkpolytest_" + args.model_name
                   + "_" + str(args.lr_fusion) + "_" + str(args.batch_size) + str(epoch)
                   + "_" + str(cl2cbk1) + ".pt")
    log_writer.close()

    if cb2cl_lst:
        print(f"Maximum recall@1 for CB2CL: {max(cb2cl_lst)} at epoch {cb2cl_lst.index(max(cb2cl_lst))+1}")
        print(f"Minimum EER for CB2CL: {min(eer_cb2cl_lst)} at epoch {eer_cb2cl_lst.index(min(eer_cb2cl_lst))+1}")
        print(f"Maximum TAR@FAR=10^-2 for CB2CL: {max(cbcltf102_lst)} at epoch {cbcltf102_lst.index(max(cbcltf102_lst))+1}")
        print(f"Maximum TAR@FAR=10^-3 for CB2CL: {max(cbcltf103_lst)} at epoch {cbcltf103_lst.index(max(cbcltf103_lst))+1}")
        print(f"Maximum TAR@FAR=10^-4 for CB2CL: {max(cbcltf104_lst)} at epoch {cbcltf104_lst.index(max(cbcltf104_lst))+1}")
299
+
300
+ if __name__ == '__main__':
301
+ main()
utils.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from tqdm import tqdm
5
+
6
class RetMetric(object):
    """Recall@k evaluated on a precomputed similarity matrix.

    Row ``i`` of ``sim_mat`` holds the similarities of query ``i`` against
    every gallery item. ``labels`` is a ``(gallery_labels, query_labels)``
    pair; the labels are compared elementwise and used as boolean masks on
    the rows of ``sim_mat`` (assumes NumPy arrays -- TODO confirm with callers).
    """

    def __init__(self, sim_mat, labels):
        self.gallery_labels, self.query_labels = labels
        # Always False here, so recall_k takes the plain-maximum branch unless
        # a caller flips this attribute after construction.
        self.is_equal_query = False
        self.sim_mat = sim_mat

    def recall_k(self, k=1):
        """Return the fraction of queries with fewer than ``k`` negatives
        scoring strictly above the positive threshold.

        The threshold is the best positive similarity, or the second best
        when ``is_equal_query`` is set and more than one positive exists.
        """
        num_queries = len(self.sim_mat)
        hits = 0

        for query_idx, scores in enumerate(self.sim_mat):
            query_label = self.query_labels[query_idx]
            positives = scores[self.gallery_labels == query_label]
            negatives = scores[self.gallery_labels != query_label]

            if self.is_equal_query and len(positives) > 1:
                cutoff = np.sort(positives)[-2]
            else:
                cutoff = np.max(positives)

            outranking = np.sum(negatives > cutoff)
            if outranking < k:
                hits += 1

        return float(hits) / num_queries
26
+
27
class Prev_RetMetric(object):
    """Recall@k computed directly from feature matrices.

    The similarity matrix is the plain matrix product of the query features
    with the transposed gallery features; no normalisation is applied here.

    Args:
        feats: either a two-element list/tuple
            ``[gallery_feats, query_feats]`` or a single feature matrix that
            serves as both gallery and query set.
        labels: ``[gallery_labels, query_labels]`` when ``feats`` is a pair,
            otherwise a single label array. Labels are compared elementwise
            and used as boolean masks (assumes NumPy arrays -- TODO confirm).
        cl2cl: when true, the diagonal of the similarity matrix is multiplied
            by zero. This needs a square matrix, i.e. as many queries as
            gallery items.
    """

    def __init__(self, feats, labels, cl2cl=True):
        if isinstance(feats, (list, tuple)) and len(feats) == 2:
            # Separate gallery and query sets.
            self.is_equal_query = False
            self.gallery_feats, self.query_feats = feats
            self.gallery_labels, self.query_labels = labels
        else:
            # One set used as both gallery and query.
            self.is_equal_query = True
            self.gallery_feats = self.query_feats = feats
            self.gallery_labels = self.query_labels = labels

        self.sim_mat = np.matmul(self.query_feats, np.transpose(self.gallery_feats))
        if cl2cl:
            # Zero the diagonal so the i-th query's score against the i-th
            # gallery item no longer contributes its original value.
            self.sim_mat = self.sim_mat * (1 - np.identity(self.sim_mat.shape[0]))

    def recall_k(self, k=1):
        """Return the fraction of queries with fewer than ``k`` negatives
        scoring strictly above the positive threshold.

        With separate gallery/query sets the threshold is the best positive
        similarity. In single-set mode it is the second-best positive, so the
        top positive (presumably the self-match -- confirm) is skipped. When
        only one positive exists there is no second-best, so that single
        value is used instead of raising ``IndexError``; this matches the
        guard in ``RetMetric.recall_k``.

        NOTE(review): with a single feature matrix and ``cl2cl=True`` the
        diagonal is already zeroed, so skipping the top positive may discard
        a genuine match -- confirm the intended combination with callers.
        """
        m = len(self.sim_mat)
        match_counter = 0

        for i in range(m):
            query_label = self.query_labels[i]
            pos_sim = self.sim_mat[i][self.gallery_labels == query_label]
            neg_sim = self.sim_mat[i][self.gallery_labels != query_label]

            if self.is_equal_query and len(pos_sim) > 1:
                thresh = np.sort(pos_sim)[-2]
            else:
                thresh = np.max(pos_sim)

            if np.sum(neg_sim > thresh) < k:
                match_counter += 1
        return float(match_counter) / m
63
+
64
def compute_recall_at_k(similarity_matrix, p_labels, g_labels, k):
    """Return the fraction of probes with a same-label gallery item in their top ``k``.

    Args:
        similarity_matrix: row ``i`` holds the scores of probe ``i`` against
            every gallery item; each row is ranked with ``torch.argsort``.
        p_labels: probe labels; ``p_labels.size(0)`` gives the probe count.
        g_labels: gallery labels, indexed by the ranked gallery positions.
        k: number of top-ranked gallery items inspected per probe.

    Returns:
        A Python float: hits divided by the number of probes.
    """
    total_probes = p_labels.size(0)
    hits = 0.0
    for row in range(total_probes):
        wanted = p_labels[row]
        ranking = torch.argsort(similarity_matrix[row], descending=True)
        # Stop at the first matching gallery label within the top-k slice.
        for gallery_idx in ranking[:k]:
            if g_labels[gallery_idx] == wanted:
                hits += 1
                break
    return hits / total_probes
76
+
77
def count_parameters(model):
    """Return the total element count of the parameters of ``model`` that require gradients."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable
79
+
80
def l2_norm(input):
    """Divide each row of ``input`` by its L2 norm taken over dimension 1.

    A constant ``1e-12`` is added to the squared sum before the square root,
    so an all-zero row yields zeros instead of a division by zero. The result
    has the same size as ``input``.
    """
    squared_sum = torch.pow(input, 2).sum(1) + 1e-12
    row_norm = squared_sum.sqrt().view(-1, 1).expand_as(input)
    return torch.div(input, row_norm).view(input.size())
89
+
90
def compute_sharded_cosine_similarity(tensor1, tensor2, shard_size):
    """Average pairwise cosine similarity between two batches, computed in shards.

    For each pair ``(i, j)`` the cosine similarity is taken between every
    vector of ``tensor1[i]`` and every vector of ``tensor2[j]`` along the last
    dimension, and the results are averaged.

    Args:
        tensor1: tensor of shape ``(B1, T1, D)``.
        tensor2: tensor of shape ``(B2, T2, D)``. ``B2`` and ``T2`` may differ
            from ``B1`` and ``T1``.
        shard_size: maximum number of batch items taken from each tensor per
            step, so the intermediate similarity tensor is at most
            ``(shard_size, shard_size, T1, T2)``.

    Returns:
        A ``(B1, B2)`` tensor, created on ``tensor1``'s device, of averaged
        similarities.
    """
    num_rows, tokens1, _ = tensor1.shape
    num_cols, tokens2, _ = tensor2.shape
    average_sim_matrix = torch.zeros((num_rows, num_cols), device=tensor1.device)

    for start_idx1 in tqdm(range(0, num_rows, shard_size)):
        end_idx1 = min(start_idx1 + shard_size, num_rows)
        # (s1, T1, D) -> (s1, 1, T1, 1, D); invariant across the inner loop.
        shard_tensor1_expanded = tensor1[start_idx1:end_idx1].unsqueeze(1).unsqueeze(3)

        for start_idx2 in range(0, num_cols, shard_size):
            end_idx2 = min(start_idx2 + shard_size, num_cols)
            # (s2, T2, D) -> (1, s2, 1, T2, D) so both shards broadcast together.
            shard_tensor2_expanded = tensor2[start_idx2:end_idx2].unsqueeze(0).unsqueeze(2)

            # Cosine similarity over D gives a (s1, s2, T1, T2) tensor.
            shard_cos_sim = F.cosine_similarity(
                shard_tensor1_expanded, shard_tensor2_expanded, dim=-1
            )

            # Sum over both token axes; normalisation happens once at the end.
            average_sim_matrix[start_idx1:end_idx1, start_idx2:end_idx2] += torch.sum(
                shard_cos_sim, dim=[2, 3]
            )

    # Turn the sums into means over the T1 * T2 token pairs.
    average_sim_matrix /= (tokens1 * tokens2)

    return average_sim_matrix