Upload 11 files
Browse files- combined_sampler.py +30 -0
- hkpoly_evaluation_phase1.py +106 -0
- hkpoly_evaluation_phase2.py +114 -0
- loss.py +377 -0
- model.py +207 -0
- rb_evaluation_phase1.py +148 -0
- rb_evaluation_phase2.py +164 -0
- requirements.txt +202 -0
- train_combined.py +273 -0
- train_combined_fusion.py +301 -0
- utils.py +117 -0
combined_sampler.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.utils.data.sampler import Sampler
|
5 |
+
from tqdm import *
|
6 |
+
|
7 |
+
class BalancedSampler(Sampler):
    """Yield dataset indices grouped by class so every batch mixes several classes.

    One "batch" worth of indices is made of ``batch_size // images_per_class``
    distinct classes, each contributing ``images_per_class`` indices drawn with
    replacement from the samples of that class.

    Args:
        data_source: Dataset exposing ``all_labels`` (one label per sample) and
            supporting ``len()``.
        batch_size: Number of indices that form one batch.
        images_per_class: Number of indices drawn for each sampled class.
    """

    def __init__(self, data_source, batch_size, images_per_class=3):
        self.data_source = data_source
        self.ys = np.array(data_source.all_labels)
        self.num_groups = batch_size // images_per_class
        self.batch_size = batch_size
        self.num_instances = images_per_class
        self.num_samples = len(self.ys)
        # Sorted distinct label values. Classes are sampled by *position* in
        # this array, so labels need not be the contiguous integers 0..C-1
        # (previously a non-contiguous label set produced empty index arrays
        # and np.random.choice raised ValueError).
        self.class_labels = np.unique(self.ys)
        self.num_classes = len(self.class_labels)
        # Sample indices of every class, computed once instead of re-scanning
        # all labels for each sampled class of each batch.
        self._indices_per_class = [
            np.flatnonzero(self.ys == class_label)
            for class_label in self.class_labels
        ]

    def __len__(self):
        # NOTE: this is the dataset size; __iter__ yields
        # num_batches * num_groups * num_instances indices, which can be
        # smaller. Kept unchanged so existing loaders report the same length.
        return self.num_samples

    def __iter__(self):
        num_batches = len(self.data_source) // self.batch_size
        ret = []
        for _ in range(num_batches):
            # Raises ValueError if num_groups exceeds the number of classes,
            # because classes are drawn without replacement.
            sampled_classes = np.random.choice(
                self.num_classes, self.num_groups, replace=False
            )
            for class_pos in sampled_classes:
                class_sel = np.random.choice(
                    self._indices_per_class[class_pos],
                    size=self.num_instances,
                    replace=True,
                )
                ret.extend(np.random.permutation(class_sel))
        return iter(ret)
|
hkpoly_evaluation_phase1.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# script to evaluate HKPolyU testing dataset on finetuned model after phase 1
|
2 |
+
import torch
|
3 |
+
from datasets.hkpoly_test import hktest
|
4 |
+
from utils import Prev_RetMetric, l2_norm, compute_recall_at_k
|
5 |
+
import numpy as np
|
6 |
+
from tqdm import tqdm
|
7 |
+
from model import SwinModel_domain_agnostic as Model
|
8 |
+
from sklearn.metrics import roc_curve, auc
|
9 |
+
import json
|
10 |
+
|
11 |
+
def calculate_tar_at_far(fpr, tpr, target_fars):
    """Look up the true-accept rate at each requested false-accept rate.

    Args:
        fpr: Array of false-positive rates of a ROC curve.
        tpr: Array of true-positive rates aligned with ``fpr``.
        target_fars: Iterable of FAR operating points to evaluate.

    Returns:
        Dict mapping each requested FAR to its TAR. When the FAR appears in
        ``fpr`` the TAR of its first occurrence is used; otherwise the value
        is linearly interpolated with ``np.interp``.
    """
    results = {}
    for target in target_fars:
        if target in fpr:
            # Exact operating point: take the first matching ROC sample.
            matches = np.where(fpr == target)
            results[target] = tpr[matches][0]
            continue
        results[target] = np.interp(target, fpr, tpr)
    return results
|
20 |
+
|
21 |
+
# Evaluation entry point: embeds every contactless / contact-based test pair
# with the phase-1 model, scores all cross-modal pairs, and prints ROC-AUC,
# EER, TAR@FAR and recall@k.
if __name__ == '__main__':
    # Hard-coded to CUDA; the script does not fall back to CPU.
    device = torch.device('cuda')
    data = hktest(split = 'test')
    dataloader = torch.utils.data.DataLoader(data,batch_size = 16, num_workers = 1, pin_memory = True)
    model = Model().to(device)
    checkpoint = torch.load("ridgeformer_checkpoints/phase1_ft_hkpoly.pt",map_location = torch.device('cpu'))
    # strict=False: keys missing from / unexpected in the checkpoint are not reported as errors.
    model.load_state_dict(checkpoint,strict=False)
    model.eval()

    cl_feats, cb_feats, cl_labels, cb_labels, cl_feats_unnormed, cb_feats_unnormed = list(),list(),list(),list(),list(),list()
    with torch.no_grad():
        for (x_cl, x_cb, label) in tqdm(dataloader):
            x_cl, x_cb, label = x_cl.to(device), x_cb.to(device), label.to(device)
            # The token outputs are returned but not used in this script.
            x_cl_feat, x_cl_token = model.get_embeddings(x_cl,'contactless')
            x_cb_feat,x_cb_token = model.get_embeddings(x_cb,'contactbased')
            # Keep the raw embeddings for the Euclidean-distance term below.
            cl_feats_unnormed.append(x_cl_feat.cpu().detach().numpy())
            cb_feats_unnormed.append(x_cb_feat.cpu().detach().numpy())
            # L2-normalised copies feed the dot-product (cosine) term.
            x_cl_feat = l2_norm(x_cl_feat).cpu().detach().numpy()
            x_cb_feat = l2_norm(x_cb_feat).cpu().detach().numpy()
            label = label.cpu().detach().numpy()
            cl_feats.append(x_cl_feat)
            cb_feats.append(x_cb_feat)
            # Both modalities of a pair share the same identity label.
            cl_labels.append(label)
            cb_labels.append(label)

    cl_feats = np.concatenate(cl_feats)
    cb_feats = np.concatenate(cb_feats)
    cl_feats_unnormed = np.concatenate(cl_feats_unnormed)
    cb_feats_unnormed = np.concatenate(cb_feats_unnormed)
    cl_label = torch.from_numpy(np.concatenate(cl_labels))
    cb_label = torch.from_numpy(np.concatenate(cb_labels))

    # CB2CL
    # Broadcasting builds an (N_cl, N_cb, D) difference array before reducing over D.
    squared_diff = np.sum(np.square(cl_feats_unnormed[:, np.newaxis] - cb_feats_unnormed), axis=2)
    # Negated so that a larger value means "closer", like the similarity term.
    distance = -1 * np.sqrt(squared_diff)
    similarities = np.dot(cl_feats,np.transpose(cb_feats))
    # Final score = dot product of normalised features + 0.1 * (negative Euclidean distance).
    scores_mat = similarities + 0.1 * distance

    scores = scores_mat.flatten().tolist()
    # Pair is genuine (True) when the contactless and contact-based labels are equal.
    labels = torch.eq(cl_label.view(-1,1) - cb_label.view(1,-1),0.0).flatten().tolist()
    # NOTE(review): ids_mod is filled here but not read anywhere in this script.
    ids_mod = list()
    for i in labels:
        if i==True:
            ids_mod.append(1)
        else:
            ids_mod.append(0)

    fpr,tpr,thresh = roc_curve(labels,scores,drop_intermediate=True)
    # Indices of the ROC samples just below / at-or-above FAR = 1e-2.
    lower_fpr_idx = max(i for i, val in enumerate(fpr) if val < 0.01)
    upper_fpr_idx = min(i for i, val in enumerate(fpr) if val >= 0.01)
    # NOTE(review): FAR=1e-2 uses the upper sample only, while 1e-3 and 1e-4
    # below average the two bracketing samples — confirm this is intended.
    tar_far_102 = tpr[upper_fpr_idx]
    print(tpr[lower_fpr_idx], lower_fpr_idx, fpr[lower_fpr_idx], thresh[lower_fpr_idx])
    print(tpr[upper_fpr_idx], upper_fpr_idx, fpr[upper_fpr_idx], thresh[upper_fpr_idx])

    lower_fpr_idx = max(i for i, val in enumerate(fpr) if val < 0.001)
    upper_fpr_idx = min(i for i, val in enumerate(fpr) if val >= 0.001)
    tar_far_103 = (tpr[lower_fpr_idx]+tpr[upper_fpr_idx])/2
    print(tpr[lower_fpr_idx], lower_fpr_idx, fpr[lower_fpr_idx])
    print(tpr[upper_fpr_idx], upper_fpr_idx, fpr[upper_fpr_idx])

    lower_fpr_idx = max(i for i, val in enumerate(fpr) if val < 0.0001)
    upper_fpr_idx = min(i for i, val in enumerate(fpr) if val >= 0.0001)
    tar_far_104 = (tpr[lower_fpr_idx]+tpr[upper_fpr_idx])/2
    print(tpr[lower_fpr_idx], lower_fpr_idx, fpr[lower_fpr_idx])
    print(tpr[upper_fpr_idx], upper_fpr_idx, fpr[upper_fpr_idx])

    fnr = 1 - tpr
    # EER is taken as the FPR at the ROC sample where |FNR - FPR| is smallest.
    EER = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    roc_auc = auc(fpr, tpr)
    print(f"ROCAUC for CB2CL: {roc_auc * 100} %")
    print(f"EER for CB2CL: {EER * 100} %")
    eer_cb2cl = EER * 100

    cbcltf102 = tar_far_102 * 100
    cbcltf103 = tar_far_103 * 100
    cbcltf104 = tar_far_104 * 100
    cl_label = cl_label.cpu().detach()
    cb_label = cb_label.cpu().detach()
    print(f"TAR@FAR=10^-2 for CB2CL: {tar_far_102 * 100} %")
    print(f"TAR@FAR=10^-3 for CB2CL: {tar_far_103 * 100} %")
    print(f"TAR@FAR=10^-4 for CB2CL: {tar_far_104 * 100} %")

    print(f"R@1 for CB2CL: {compute_recall_at_k(torch.from_numpy(scores_mat), cl_label, cb_label, 1) * 100} %")
    print(f"R@10 for CB2CL: {compute_recall_at_k(torch.from_numpy(scores_mat), cl_label, cb_label, 10) * 100} %")
    print(f"R@50 for CB2CL: {compute_recall_at_k(torch.from_numpy(scores_mat), cl_label, cb_label, 50) * 100} %")
    print(f"R@100 for CB2CL: {compute_recall_at_k(torch.from_numpy(scores_mat), cl_label, cb_label, 100) * 100} %")
|
hkpoly_evaluation_phase2.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# script to evaluate HKPolyU testing dataset on finetuned model after phase 2
|
2 |
+
import torch
|
3 |
+
from datasets.hkpoly_test import hktest
|
4 |
+
from utils import Prev_RetMetric, l2_norm, compute_recall_at_k
|
5 |
+
import numpy as np
|
6 |
+
from tqdm import tqdm
|
7 |
+
from model import SwinModel_Fusion as Model
|
8 |
+
from sklearn.metrics import roc_curve, auc
|
9 |
+
import json
|
10 |
+
|
11 |
+
def calculate_tar_at_far(fpr, tpr, target_fars):
    """Return ``{far: tar}`` for every false-accept rate in ``target_fars``.

    Args:
        fpr: Array of ROC false-positive rates.
        tpr: Array of ROC true-positive rates aligned with ``fpr``.
        target_fars: Iterable of FAR values to report.

    Returns:
        Dict keyed by the requested FAR. An exact hit in ``fpr`` yields the TAR
        of the first matching sample; any other FAR is linearly interpolated.
    """
    def _tar_for(far_value):
        # First exact ROC sample if present, otherwise interpolate.
        if far_value in fpr:
            return tpr[np.where(fpr == far_value)][0]
        return np.interp(far_value, fpr, tpr)

    tar_by_far = {}
    for far_value in target_fars:
        tar_by_far[far_value] = _tar_for(far_value)
    return tar_by_far
|
20 |
+
|
21 |
+
def get_fused_cross_score_matrix(model, cl_tokens, cb_tokens, shard_size=20):
    """Score every (contactless, contact-based) token pair with the fusion head.

    The pairs are processed in ``shard_size x shard_size`` tiles so that only
    one tile of expanded token pairs is materialised at a time.

    Args:
        model: Object exposing ``combine_features(a, b)`` that returns
            ``(similarity_scores, distances)``, one value per pair.
        cl_tokens: List of contactless token tensors, concatenated on dim 0.
            Assumed to be 3-D ``(N, T, D)`` after concatenation — TODO confirm.
        cb_tokens: List of contact-based token tensors, same layout.
        shard_size: Tile edge length (defaults to the previous fixed value 20).

    Returns:
        CPU tensor of shape ``(num_cl, num_cb)`` holding
        ``similarity - 0.1 * distance`` for every pair.
    """
    cl_tokens = torch.cat(cl_tokens)
    cb_tokens = torch.cat(cb_tokens)
    num_cl = cl_tokens.shape[0]
    num_cb = cb_tokens.shape[0]
    similarity_matrix = torch.zeros((num_cl, num_cb))
    # Results are detached below, so gradients are never needed here.
    with torch.no_grad():
        for i_start in tqdm(range(0, num_cl, shard_size)):
            i_end = min(i_start + shard_size, num_cl)
            shard_i = cl_tokens[i_start:i_end]
            rows = shard_i.shape[0]
            for j_start in range(0, num_cb, shard_size):
                j_end = min(j_start + shard_size, num_cb)
                shard_j = cb_tokens[j_start:j_end]
                cols = shard_j.shape[0]
                # Expand with the *actual* tile sizes: the last tile is smaller
                # than shard_size whenever the sample count is not a multiple
                # of it, which previously broke expand()/reshape().
                pairwise_i = shard_i.unsqueeze(1).expand(-1, cols, -1, -1)
                pairwise_j = shard_j.unsqueeze(0).expand(rows, -1, -1, -1)
                similarity_scores, distances = model.combine_features(
                    pairwise_i.reshape(-1, *shard_i.shape[1:]),
                    pairwise_j.reshape(-1, *shard_j.shape[1:]),
                )
                scores = similarity_scores - 0.1 * distances
                scores = scores.reshape(rows, cols)
                similarity_matrix[i_start:i_end, j_start:j_end] = scores.cpu().detach()
    return similarity_matrix
|
42 |
+
|
43 |
+
# Evaluation entry point: extracts token sequences for every contactless /
# contact-based test pair with the phase-2 fusion model, scores all cross-modal
# pairs, and prints ROC-AUC, EER, TAR@FAR and recall@k.
if __name__ == '__main__':
    # Hard-coded to CUDA; the script does not fall back to CPU.
    device = torch.device('cuda')
    data = hktest(split = 'test')
    dataloader = torch.utils.data.DataLoader(data,batch_size = 16, num_workers = 1, pin_memory = True)
    model = Model().to(device)
    checkpoint = torch.load("ridgeformer_checkpoints/phase2_ft_hkpoly.pt",map_location = torch.device('cpu'))
    # strict=False: keys missing from / unexpected in the checkpoint are not reported as errors.
    model.load_state_dict(checkpoint,strict=False)
    model.eval()

    # NOTE(review): cl_feats_unnormed / cb_feats_unnormed are created but never filled in this script.
    cl_feats, cb_feats, cl_labels, cb_labels, cl_feats_unnormed, cb_feats_unnormed = list(),list(),list(),list(),list(),list()
    with torch.no_grad():
        for (x_cl, x_cb, label) in tqdm(dataloader):
            x_cl, x_cb, label = x_cl.to(device), x_cb.to(device), label.to(device)
            x_cl_token = model.get_tokens(x_cl,'contactless')
            x_cb_token = model.get_tokens(x_cb,'contactbased')
            label = label.cpu().detach().numpy()
            # Despite the names, these lists hold token tensors (not pooled features);
            # they are appended without .cpu(), so they stay on the model's device.
            cl_feats.append(x_cl_token)
            cb_feats.append(x_cb_token)
            # Both modalities of a pair share the same identity label.
            cl_labels.append(label)
            cb_labels.append(label)

    cl_label = torch.from_numpy(np.concatenate(cl_labels))
    cb_label = torch.from_numpy(np.concatenate(cb_labels))

    # CB2CL
    scores_mat = get_fused_cross_score_matrix(model, cl_feats, cb_feats)
    scores = scores_mat.cpu().detach().numpy().flatten().tolist()
    # Pair is genuine (True) when the contactless and contact-based labels are equal.
    labels = torch.eq(cl_label.view(-1,1) - cb_label.view(1,-1),0.0).flatten().tolist()
    # NOTE(review): ids_mod is filled here but not read anywhere in this script.
    ids_mod = list()
    for i in labels:
        if i==True:
            ids_mod.append(1)
        else:
            ids_mod.append(0)

    fpr,tpr,thresh = roc_curve(labels,scores,drop_intermediate=True)
    # Indices of the ROC samples just below / at-or-above FAR = 1e-2.
    lower_fpr_idx = max(i for i, val in enumerate(fpr) if val < 0.01)
    upper_fpr_idx = min(i for i, val in enumerate(fpr) if val >= 0.01)
    # NOTE(review): FAR=1e-2 uses the upper sample only, while 1e-3 and 1e-4
    # below average the two bracketing samples — confirm this is intended.
    tar_far_102 = tpr[upper_fpr_idx]

    lower_fpr_idx = max(i for i, val in enumerate(fpr) if val < 0.001)
    upper_fpr_idx = min(i for i, val in enumerate(fpr) if val >= 0.001)
    tar_far_103 = (tpr[lower_fpr_idx]+tpr[upper_fpr_idx])/2

    lower_fpr_idx = max(i for i, val in enumerate(fpr) if val < 0.0001)
    upper_fpr_idx = min(i for i, val in enumerate(fpr) if val >= 0.0001)
    tar_far_104 = (tpr[lower_fpr_idx]+tpr[upper_fpr_idx])/2

    fnr = 1 - tpr
    # EER is taken as the FPR at the ROC sample where |FNR - FPR| is smallest.
    EER = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    roc_auc = auc(fpr, tpr)
    print(f"ROCAUC for CB2CL: {roc_auc * 100} %")
    print(f"EER for CB2CL: {EER * 100} %")
    eer_cb2cl = EER * 100
    cbcltf102 = tar_far_102 * 100
    cbcltf103 = tar_far_103 * 100
    cbcltf104 = tar_far_104 * 100
    cl_label = cl_label.cpu().detach()
    cb_label = cb_label.cpu().detach()

    print(f"TAR@FAR=10^-2 for CB2CL: {tar_far_102 * 100} %")
    print(f"TAR@FAR=10^-3 for CB2CL: {tar_far_103 * 100} %")
    print(f"TAR@FAR=10^-4 for CB2CL: {tar_far_104 * 100} %")

    # Recall at every k in 1..100; NOTE(review): recall_dict is not read again in this script.
    recall_dict = dict()
    for i in range(1,101):
        recall_dict[i] = compute_recall_at_k(scores_mat, cl_label, cb_label, i)

    print(f"R@1 for CB2CL: {compute_recall_at_k(scores_mat, cl_label, cb_label, 1) * 100} %")
    print(f"R@10 for CB2CL: {compute_recall_at_k(scores_mat, cl_label, cb_label, 10) * 100} %")
    print(f"R@50 for CB2CL: {compute_recall_at_k(scores_mat, cl_label, cb_label, 50) * 100} %")
    print(f"R@100 for CB2CL: {compute_recall_at_k(scores_mat, cl_label, cb_label, 100) * 100} %")
|
loss.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pytorch_metric_learning import losses
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.init
|
5 |
+
import torchvision.models as models
|
6 |
+
from torch.autograd import Variable
|
7 |
+
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
8 |
+
from torch.nn.utils.weight_norm import weight_norm
|
9 |
+
import torch.backends.cudnn as cudnn
|
10 |
+
from torch.nn.utils.clip_grad import clip_grad_norm
|
11 |
+
import numpy as np
|
12 |
+
import os
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import itertools
|
15 |
+
|
16 |
+
# Turns on autograd anomaly detection for the whole process as soon as this
# module is imported. PyTorch documents this mode as a debugging aid that slows
# execution — NOTE(review): confirm it is meant to stay enabled for real runs.
torch.autograd.set_detect_anomaly(True)
|
17 |
+
|
18 |
+
class DualMSLoss_FineGrained(nn.Module):
    """Hard-mined contrastive loss over contactless / contact-based embeddings.

    ``forward`` builds three cosine-similarity matrices (contactless vs
    contactless, contact-based vs contact-based, contact-based vs contactless)
    and sums the mined loss of each matrix and of its transpose.

    All masks are created on the device of the similarity matrix, so the loss
    runs wherever its inputs live (previously every mask and loss term was
    forced onto CUDA with ``.cuda()``).
    """

    def __init__(self, margin=0, max_violation=False):
        super(DualMSLoss_FineGrained, self).__init__()
        self.max_violation = max_violation
        self.thresh = 0.5
        # NOTE: the ``margin`` argument is accepted for interface compatibility
        # only; the mining margin has always been overwritten with 0.7.
        self.margin = 0.7  # 0.1
        self.scale_pos = 2
        self.scale_neg = 40.0

    @staticmethod
    def _same_label_mask(label, device):
        """Return a float (n, n) matrix with 1.0 where two labels are equal."""
        return torch.eq(label.view(-1, 1) - label.view(1, -1), 0.0).float().to(device)

    def _mined_loss(self, sim_mat, is_pos, is_neg):
        """Shared tail of the mining variants: hard-pair selection and log-sum loss.

        ``is_pos`` / ``is_neg`` are boolean masks selecting the entries of
        ``sim_mat`` that count as positive / negative pairs.
        """
        pos_exp = torch.exp(-self.scale_pos * (sim_mat - self.thresh))
        neg_exp = torch.exp(self.scale_neg * (sim_mat - self.thresh))
        # Sentinels (+/-1e16) keep non-selected entries out of the min / max.
        P_sim = torch.where(is_pos, sim_mat, torch.ones_like(pos_exp) * 1e16)
        N_sim = torch.where(is_neg, sim_mat, torch.ones_like(neg_exp) * -1e16)
        min_P_sim, _ = torch.min(P_sim, dim=1, keepdim=True)
        max_N_sim, _ = torch.max(N_sim, dim=1, keepdim=True)
        # A positive is "hard" if it is within ``margin`` of the row's hardest
        # negative; a negative is "hard" if within ``margin`` of the row's
        # hardest positive.
        hard_P_sim = torch.where(P_sim - self.margin < max_N_sim, pos_exp, torch.zeros_like(pos_exp)).sum(dim=-1)
        hard_N_sim = torch.where(N_sim + self.margin > min_P_sim, neg_exp, torch.zeros_like(neg_exp)).sum(dim=-1)
        pos_loss = torch.log(1 + hard_P_sim).sum() / self.scale_pos
        neg_loss = torch.log(1 + hard_N_sim).sum() / self.scale_neg
        return pos_loss + neg_loss

    def ms_sample(self, sim_mat, label):
        """Mined loss where equal labels are positives and all others negatives."""
        pos_mask = self._same_label_mask(label, sim_mat.device)
        neg_mask = 1 - pos_mask
        return self._mined_loss(sim_mat, pos_mask == 1, neg_mask == 1)

    def ms_sample_cbcb_clcl(self, sim_mat, label):
        """Mined loss for a same-modality matrix; self-pairs are ignored.

        Adding the identity turns the diagonal of the mask into 2, so diagonal
        entries are neither positives (== 1) nor negatives (== 0).
        """
        pos_mask = self._same_label_mask(label, sim_mat.device)
        pos_mask = pos_mask + torch.eye(pos_mask.shape[0], device=sim_mat.device)
        return self._mined_loss(sim_mat, pos_mask == 1, pos_mask == 0)

    def ms_sample_cbcb_clcl_trans(self, sim_mat, label):
        """Mined loss for an (n, n-1) matrix whose diagonal has been removed.

        The label mask has its diagonal dropped row by row so that it lines up
        with ``sim_mat``.
        """
        pos_mask = self._same_label_mask(label, sim_mat.device)
        n_sha = pos_mask.shape[0]
        off_diagonal = ~torch.eye(n_sha, dtype=torch.bool, device=pos_mask.device)
        pos_mask = pos_mask[off_diagonal].reshape(n_sha, n_sha - 1)
        neg_mask = 1 - pos_mask
        return self._mined_loss(sim_mat, pos_mask == 1, neg_mask == 1)

    def compute_sharded_cosine_similarity(self, tensor1, tensor2, shard_size):
        """Accumulate token-wise cosine similarities between all sample pairs.

        Args:
            tensor1: Tensor of shape (B, T, D).
            tensor2: Tensor sliced and broadcast against ``tensor1``.
            shard_size: Number of token positions handled per iteration.

        Returns:
            (B, B) tensor of summed similarities divided by ``T * T``.

        NOTE(review): only token pairs that fall inside the same shard are
        summed, yet the divisor is always T*T; the result is a true mean only
        when ``shard_size >= T``.
        """
        B, T, D = tensor1.shape
        average_sim_matrix = torch.zeros((B, B), device=tensor1.device)

        for start_idx in range(0, T, shard_size):
            end_idx = min(start_idx + shard_size, T)
            shard_tensor1 = tensor1[:, start_idx:end_idx, :]
            shard_tensor2 = tensor2[:, start_idx:end_idx, :]
            # (B, 1, t, 1, D) against (1, B, 1, t, D) -> (B, B, t, t)
            shard_tensor1_expanded = shard_tensor1.unsqueeze(1).unsqueeze(3)
            shard_tensor2_expanded = shard_tensor2.unsqueeze(0).unsqueeze(2)
            shard_cos_sim = F.cosine_similarity(shard_tensor1_expanded, shard_tensor2_expanded, dim=-1)
            average_sim_matrix += torch.sum(shard_cos_sim, dim=[2, 3])

        average_sim_matrix /= (T * T)
        return average_sim_matrix

    def forward(self, x_contactless, x_contactbased, x_cl_tokens, x_cb_tokens, labels, device):
        """Return the summed intra- and cross-modality mined losses.

        ``x_cl_tokens``, ``x_cb_tokens`` and ``device`` are accepted for
        interface compatibility and are not used.
        """
        cl_unit = self.l2_norm(x_contactless)
        cb_unit = self.l2_norm(x_contactbased)
        sim_mat_clcl = F.linear(cl_unit, cl_unit)
        sim_mat_cbcb = F.linear(cb_unit, cb_unit)
        sim_mat_cbcl = F.linear(cb_unit, cl_unit)

        loss2 = self.ms_sample_cbcb_clcl(sim_mat_clcl, labels) + self.ms_sample_cbcb_clcl(sim_mat_clcl.t(), labels)
        loss3 = self.ms_sample_cbcb_clcl(sim_mat_cbcb, labels) + self.ms_sample_cbcb_clcl(sim_mat_cbcb.t(), labels)
        loss4 = self.ms_sample(sim_mat_cbcl, labels) + self.ms_sample(sim_mat_cbcl.t(), labels)
        return loss4 + loss2 + loss3

    def l2_norm(self, input):
        """Scale each row of a 2-D ``input`` to unit L2 norm (epsilon 1e-12)."""
        input_size = input.size()
        normp = torch.sum(torch.pow(input, 2), 1) + 1e-12
        norm = torch.sqrt(normp)
        _output = torch.div(input, norm.view(-1, 1).expand_as(input))
        return _output.view(input_size)
|
134 |
+
|
135 |
+
class DualMSLoss_FineGrained_domain_agnostic(nn.Module):
    """Hard-mined contrastive loss plus a weighted domain-classification term.

    ``forward`` sums the mined loss of three cosine-similarity matrices
    (contactless vs contactless, contact-based vs contact-based, contact-based
    vs contactless), each together with its transpose, and adds three times
    the cross-entropy of the domain predictions.

    All masks are created on the device of the similarity matrix, so the loss
    runs wherever its inputs live (previously every mask and loss term was
    forced onto CUDA with ``.cuda()``).
    """

    def __init__(self, margin=0, max_violation=False):
        super(DualMSLoss_FineGrained_domain_agnostic, self).__init__()
        self.max_violation = max_violation
        self.thresh = 0.5
        # NOTE: the ``margin`` argument is accepted for interface compatibility
        # only; the mining margin has always been overwritten with 0.5.
        self.margin = 0.5  # 0.1
        self.scale_pos = 2
        self.scale_neg = 40.0
        self.criterion = nn.CrossEntropyLoss()

    @staticmethod
    def _same_label_mask(label, device):
        """Return a float (n, n) matrix with 1.0 where two labels are equal."""
        return torch.eq(label.view(-1, 1) - label.view(1, -1), 0.0).float().to(device)

    def _mined_loss(self, sim_mat, is_pos, is_neg):
        """Shared tail of the mining variants: hard-pair selection and log-sum loss.

        ``is_pos`` / ``is_neg`` are boolean masks selecting the entries of
        ``sim_mat`` that count as positive / negative pairs.
        """
        pos_exp = torch.exp(-self.scale_pos * (sim_mat - self.thresh))
        neg_exp = torch.exp(self.scale_neg * (sim_mat - self.thresh))
        # Sentinels (+/-1e16) keep non-selected entries out of the min / max.
        P_sim = torch.where(is_pos, sim_mat, torch.ones_like(pos_exp) * 1e16)
        N_sim = torch.where(is_neg, sim_mat, torch.ones_like(neg_exp) * -1e16)
        min_P_sim, _ = torch.min(P_sim, dim=1, keepdim=True)
        max_N_sim, _ = torch.max(N_sim, dim=1, keepdim=True)
        # A positive is "hard" if it is within ``margin`` of the row's hardest
        # negative; a negative is "hard" if within ``margin`` of the row's
        # hardest positive.
        hard_P_sim = torch.where(P_sim - self.margin < max_N_sim, pos_exp, torch.zeros_like(pos_exp)).sum(dim=-1)
        hard_N_sim = torch.where(N_sim + self.margin > min_P_sim, neg_exp, torch.zeros_like(neg_exp)).sum(dim=-1)
        pos_loss = torch.log(1 + hard_P_sim).sum() / self.scale_pos
        neg_loss = torch.log(1 + hard_N_sim).sum() / self.scale_neg
        return pos_loss + neg_loss

    def ms_sample(self, sim_mat, label):
        """Mined loss where equal labels are positives and all others negatives."""
        pos_mask = self._same_label_mask(label, sim_mat.device)
        neg_mask = 1 - pos_mask
        return self._mined_loss(sim_mat, pos_mask == 1, neg_mask == 1)

    def ms_sample_cbcb_clcl(self, sim_mat, label):
        """Mined loss for a same-modality matrix; self-pairs are ignored.

        Adding the identity turns the diagonal of the mask into 2, so diagonal
        entries are neither positives (== 1) nor negatives (== 0).
        """
        pos_mask = self._same_label_mask(label, sim_mat.device)
        pos_mask = pos_mask + torch.eye(pos_mask.shape[0], device=sim_mat.device)
        return self._mined_loss(sim_mat, pos_mask == 1, pos_mask == 0)

    def ms_sample_cbcb_clcl_trans(self, sim_mat, label):
        """Mined loss for an (n, n-1) matrix whose diagonal has been removed.

        The label mask has its diagonal dropped row by row so that it lines up
        with ``sim_mat``.
        """
        pos_mask = self._same_label_mask(label, sim_mat.device)
        n_sha = pos_mask.shape[0]
        off_diagonal = ~torch.eye(n_sha, dtype=torch.bool, device=pos_mask.device)
        pos_mask = pos_mask[off_diagonal].reshape(n_sha, n_sha - 1)
        neg_mask = 1 - pos_mask
        return self._mined_loss(sim_mat, pos_mask == 1, neg_mask == 1)

    def compute_sharded_cosine_similarity(self, tensor1, tensor2, shard_size):
        """Accumulate token-wise cosine similarities between all sample pairs.

        Args:
            tensor1: Tensor of shape (B, T, D).
            tensor2: Tensor sliced and broadcast against ``tensor1``.
            shard_size: Number of token positions handled per iteration.

        Returns:
            (B, B) tensor of summed similarities divided by ``T * T``.

        NOTE(review): only token pairs that fall inside the same shard are
        summed, yet the divisor is always T*T; the result is a true mean only
        when ``shard_size >= T``.
        """
        B, T, D = tensor1.shape
        average_sim_matrix = torch.zeros((B, B), device=tensor1.device)

        for start_idx in range(0, T, shard_size):
            end_idx = min(start_idx + shard_size, T)
            shard_tensor1 = tensor1[:, start_idx:end_idx, :]
            shard_tensor2 = tensor2[:, start_idx:end_idx, :]
            # (B, 1, t, 1, D) against (1, B, 1, t, D) -> (B, B, t, t)
            shard_tensor1_expanded = shard_tensor1.unsqueeze(1).unsqueeze(3)
            shard_tensor2_expanded = shard_tensor2.unsqueeze(0).unsqueeze(2)
            shard_cos_sim = F.cosine_similarity(shard_tensor1_expanded, shard_tensor2_expanded, dim=-1)
            average_sim_matrix += torch.sum(shard_cos_sim, dim=[2, 3])

        average_sim_matrix /= (T * T)
        return average_sim_matrix

    def forward(self, x_contactless, x_contactbased, x_cl_tokens, x_cb_tokens, labels, device, domain_class_cl, domain_class_cb, domain_class_cl_gt, domain_class_cb_gt):
        """Return the mined similarity losses plus 3x the domain cross-entropy.

        ``x_cl_tokens``, ``x_cb_tokens`` and ``device`` are accepted for
        interface compatibility and are not used. The domain predictions and
        targets of both modalities are concatenated before the cross-entropy.
        """
        cl_unit = self.l2_norm(x_contactless)
        cb_unit = self.l2_norm(x_contactbased)
        sim_mat_clcl = F.linear(cl_unit, cl_unit)
        sim_mat_cbcb = F.linear(cb_unit, cb_unit)
        sim_mat_cbcl = F.linear(cb_unit, cl_unit)

        loss2 = self.ms_sample_cbcb_clcl(sim_mat_clcl, labels) + self.ms_sample_cbcb_clcl(sim_mat_clcl.t(), labels)
        loss3 = self.ms_sample_cbcb_clcl(sim_mat_cbcb, labels) + self.ms_sample_cbcb_clcl(sim_mat_cbcb.t(), labels)
        loss4 = self.ms_sample(sim_mat_cbcl, labels) + self.ms_sample(sim_mat_cbcl.t(), labels)

        pred = torch.cat([domain_class_cl, domain_class_cb])
        gt = torch.cat([domain_class_cl_gt, domain_class_cb_gt])
        domain_class_loss = self.criterion(pred, gt)
        return loss4 + loss2 + loss3 + (3 * domain_class_loss)

    def l2_norm(self, input):
        """Scale each row of a 2-D ``input`` to unit L2 norm (epsilon 1e-12)."""
        input_size = input.size()
        normp = torch.sum(torch.pow(input, 2), 1) + 1e-12
        norm = torch.sqrt(normp)
        _output = torch.div(input, norm.view(-1, 1).expand_as(input))
        return _output.view(input_size)
|
258 |
+
|
259 |
+
class DualMSLoss_FineGrained_domain_agnostic_ft(nn.Module):
|
260 |
+
"""
|
261 |
+
Compute contrastive loss
|
262 |
+
"""
|
263 |
+
def __init__(self, margin=0, max_violation=False):
    # Sets the fixed hyper-parameters of the loss and the cross-entropy
    # criterion stored for the domain-classification inputs.
    super(DualMSLoss_FineGrained_domain_agnostic_ft, self).__init__()
    self.margin = margin
    self.max_violation = max_violation
    self.thresh = 0.5
    # NOTE: this overwrites the ``margin`` argument assigned above, so the
    # constructor argument has no effect on the margin actually used.
    self.margin = 0.7 # 0.1
    self.scale_pos = 2
    self.scale_neg = 40.0
    self.criterion = nn.CrossEntropyLoss()
|
272 |
+
|
273 |
+
def ms_sample(self, sim_mat, label):
    """Hard-mined loss where equal labels are positives, all others negatives.

    Args:
        sim_mat: (n, n) similarity matrix.
        label: Length-n tensor of identity labels.

    Returns:
        Scalar tensor ``pos_loss + neg_loss`` on the device of ``sim_mat``.
    """
    pos_exp = torch.exp(-self.scale_pos * (sim_mat - self.thresh))
    neg_exp = torch.exp(self.scale_neg * (sim_mat - self.thresh))
    # Build the mask on sim_mat's device instead of forcing CUDA, so the loss
    # also works on CPU or on a non-default GPU.
    pos_mask = torch.eq(label.view(-1, 1) - label.view(1, -1), 0.0).float().to(sim_mat.device)
    neg_mask = 1 - pos_mask
    # Sentinels (+/-1e16) keep non-selected entries out of the min / max.
    P_sim = torch.where(pos_mask == 1, sim_mat, torch.ones_like(pos_exp) * 1e16)
    N_sim = torch.where(neg_mask == 1, sim_mat, torch.ones_like(neg_exp) * -1e16)
    min_P_sim, _ = torch.min(P_sim, dim=1, keepdim=True)
    max_N_sim, _ = torch.max(N_sim, dim=1, keepdim=True)
    # Keep only pairs within ``margin`` of the row's hardest opposite pair.
    hard_P_sim = torch.where(P_sim - self.margin < max_N_sim, pos_exp, torch.zeros_like(pos_exp)).sum(dim=-1)
    hard_N_sim = torch.where(N_sim + self.margin > min_P_sim, neg_exp, torch.zeros_like(neg_exp)).sum(dim=-1)
    pos_loss = torch.log(1 + hard_P_sim).sum() / self.scale_pos
    neg_loss = torch.log(1 + hard_N_sim).sum() / self.scale_neg

    return pos_loss + neg_loss
|
288 |
+
|
289 |
+
def ms_sample_cbcb_clcl(self, sim_mat, label):
    """Hard-mined loss for a same-modality similarity matrix, ignoring self-pairs.

    Args:
        sim_mat: (n, n) similarity matrix of one modality against itself.
        label: Length-n tensor of identity labels.

    Returns:
        Scalar tensor ``pos_loss + neg_loss`` on the device of ``sim_mat``.
    """
    pos_exp = torch.exp(-self.scale_pos * (sim_mat - self.thresh))
    neg_exp = torch.exp(self.scale_neg * (sim_mat - self.thresh))
    # Masks follow sim_mat's device instead of being forced onto CUDA.
    pos_mask = torch.eq(label.view(-1, 1) - label.view(1, -1), 0.0).float().to(sim_mat.device)

    # Adding the identity makes the diagonal 2, so self-pairs match neither
    # the positive (== 1) nor the negative (== 0) selection below.
    pos_mask = pos_mask + torch.eye(pos_mask.shape[0], device=sim_mat.device)
    P_sim = torch.where(pos_mask == 1, sim_mat, torch.ones_like(pos_exp) * 1e16)
    N_sim = torch.where(pos_mask == 0, sim_mat, torch.ones_like(neg_exp) * -1e16)
    min_P_sim, _ = torch.min(P_sim, dim=1, keepdim=True)
    max_N_sim, _ = torch.max(N_sim, dim=1, keepdim=True)
    # Keep only pairs within ``margin`` of the row's hardest opposite pair.
    hard_P_sim = torch.where(P_sim - self.margin < max_N_sim, pos_exp, torch.zeros_like(pos_exp)).sum(dim=-1)
    hard_N_sim = torch.where(N_sim + self.margin > min_P_sim, neg_exp, torch.zeros_like(neg_exp)).sum(dim=-1)
    pos_loss = torch.log(1 + hard_P_sim).sum() / self.scale_pos
    neg_loss = torch.log(1 + hard_N_sim).sum() / self.scale_neg

    return pos_loss + neg_loss
|
305 |
+
|
306 |
+
def ms_sample_cbcb_clcl_trans(self, sim_mat, label):
    """Hard-mined loss for an (n, n-1) matrix whose diagonal has been removed.

    Args:
        sim_mat: (n, n-1) similarity matrix; row ``i`` holds the similarities
            of sample ``i`` to every other sample, in index order.
        label: Length-n tensor of identity labels.

    Returns:
        Scalar tensor ``pos_loss + neg_loss`` on the device of ``sim_mat``.
    """
    pos_exp = torch.exp(-self.scale_pos * (sim_mat - self.thresh))
    neg_exp = torch.exp(self.scale_neg * (sim_mat - self.thresh))
    # Masks follow sim_mat's device instead of being forced onto CUDA.
    pos_mask = torch.eq(label.view(-1, 1) - label.view(1, -1), 0.0).float().to(sim_mat.device)

    # Drop the diagonal row by row so the mask lines up with sim_mat. The
    # original double transpose was a no-op and is omitted.
    n_sha = pos_mask.shape[0]
    off_diagonal = ~torch.eye(n_sha, dtype=torch.bool, device=pos_mask.device)
    pos_mask = pos_mask[off_diagonal].reshape(n_sha, n_sha - 1)

    neg_mask = 1 - pos_mask
    P_sim = torch.where(pos_mask == 1, sim_mat, torch.ones_like(pos_exp) * 1e16)
    N_sim = torch.where(neg_mask == 1, sim_mat, torch.ones_like(neg_exp) * -1e16)
    min_P_sim, _ = torch.min(P_sim, dim=1, keepdim=True)
    max_N_sim, _ = torch.max(N_sim, dim=1, keepdim=True)
    # Keep only pairs within ``margin`` of the row's hardest opposite pair.
    hard_P_sim = torch.where(P_sim - self.margin < max_N_sim, pos_exp, torch.zeros_like(pos_exp)).sum(dim=-1)
    hard_N_sim = torch.where(N_sim + self.margin > min_P_sim, neg_exp, torch.zeros_like(neg_exp)).sum(dim=-1)
    pos_loss = torch.log(1 + hard_P_sim).sum() / self.scale_pos
    neg_loss = torch.log(1 + hard_N_sim).sum() / self.scale_neg

    return pos_loss + neg_loss
|
327 |
+
|
328 |
+
def compute_sharded_cosine_similarity(self, tensor1, tensor2, shard_size):
    """Average pairwise token cosine similarity between every pair of samples.

    Entry ``[i, j]`` of the result is the mean cosine similarity over all
    ``T1 * T2`` token pairs of sample ``i`` in ``tensor1`` and sample ``j`` in
    ``tensor2``.  The token axes are processed ``shard_size`` tokens at a time
    so the intermediate ``(B1, B2, t, t)`` tensor stays small; the result does
    not depend on ``shard_size``.

    Previously only token pairs falling into the same shard were summed while
    the total was still divided by ``T * T``, so the value changed with
    ``shard_size``.  For ``shard_size >= T`` the old and new results agree.

    Args:
        tensor1: tensor of shape ``(B1, T1, D)``.
        tensor2: tensor of shape ``(B2, T2, D)``.
        shard_size: number of tokens per shard (positive int).

    Returns:
        ``(B1, B2)`` tensor on ``tensor1``'s device.
    """
    B1, T1, _ = tensor1.shape
    B2, T2, _ = tensor2.shape
    average_sim_matrix = torch.zeros((B1, B2), device=tensor1.device)

    for start1 in range(0, T1, shard_size):
        # (B1, 1, t1, 1, D)
        shard1 = tensor1[:, start1:start1 + shard_size, :].unsqueeze(1).unsqueeze(3)
        for start2 in range(0, T2, shard_size):
            # (1, B2, 1, t2, D)
            shard2 = tensor2[:, start2:start2 + shard_size, :].unsqueeze(0).unsqueeze(2)

            # Broadcasts to (B1, B2, t1, t2): every token of shard1 against
            # every token of shard2, for every sample pair.
            shard_cos_sim = F.cosine_similarity(shard1, shard2, dim=-1)
            average_sim_matrix += torch.sum(shard_cos_sim, dim=[2, 3])

    # Normalize by the total number of token pairs.
    average_sim_matrix /= (T1 * T2)

    return average_sim_matrix
|
353 |
+
|
354 |
+
def forward(self, x_contactless, x_contactbased, x_cl_tokens, x_cb_tokens, labels, device, domain_class_cl, domain_class_cb, domain_class_cl_gt, domain_class_cb_gt):
    """Sum the intra- and cross-modality similarity losses for one batch.

    Cosine-similarity matrices are built from the L2-normalised embeddings:
    contactless vs contactless, contact-based vs contact-based, and
    contact-based vs contactless.  Each matrix and its transpose is scored
    (``ms_sample_cbcb_clcl`` for the two intra-modality matrices,
    ``ms_sample`` for the cross-modality one) and the six terms are added.

    Args:
        x_contactless: contactless embeddings, one row per sample.
        x_contactbased: contact-based embeddings, one row per sample.
        labels: class labels passed through to the sampling losses.
        x_cl_tokens, x_cb_tokens, device, domain_class_cl, domain_class_cb,
        domain_class_cl_gt, domain_class_cb_gt: accepted for interface
            compatibility; not used by this method.

    Returns:
        Scalar loss tensor on the device of the similarity matrices.
    """
    # Normalise each modality once; every similarity matrix reuses these.
    cl_unit = self.l2_norm(x_contactless)
    cb_unit = self.l2_norm(x_contactbased)

    sim_mat_clcl = F.linear(cl_unit, cl_unit)
    sim_mat_cbcb = F.linear(cb_unit, cb_unit)
    sim_mat_cbcl = F.linear(cb_unit, cl_unit)

    # Keep the loss terms on the device the similarities were computed on
    # rather than forcing the default CUDA device.
    out_device = sim_mat_cbcl.device

    loss2 = (self.ms_sample_cbcb_clcl(sim_mat_clcl, labels).to(out_device)
             + self.ms_sample_cbcb_clcl(sim_mat_clcl.t(), labels).to(out_device))
    loss3 = (self.ms_sample_cbcb_clcl(sim_mat_cbcb, labels).to(out_device)
             + self.ms_sample_cbcb_clcl(sim_mat_cbcb.t(), labels).to(out_device))

    loss4 = (self.ms_sample(sim_mat_cbcl, labels).to(out_device)
             + self.ms_sample(sim_mat_cbcl.t(), labels).to(out_device))

    return loss4 + loss2 + loss3
|
368 |
+
|
369 |
+
def l2_norm(self, input):
    """Scale every row of ``input`` to (approximately) unit Euclidean length.

    The squared entries are summed along dimension 1 and ``1e-12`` is added
    before the square root, so an all-zero row maps to zeros instead of NaN.
    The result has the same shape as ``input``.
    """
    squared_sum = input.pow(2).sum(1) + 1e-12
    row_norm = squared_sum.sqrt().view(-1, 1)
    normalized = input / row_norm.expand_as(input)
    return normalized.view(input.size())
|
model.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import argparse
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import torch.optim as optim
|
7 |
+
from torchvision import datasets, transforms
|
8 |
+
from torch.optim.lr_scheduler import StepLR
|
9 |
+
import torchvision.models as models
|
10 |
+
import timm
|
11 |
+
from pprint import pprint
|
12 |
+
import numpy as np
|
13 |
+
from tqdm import tqdm
|
14 |
+
from torch.utils.data.sampler import BatchSampler
|
15 |
+
from gradient_reversal.module import GradientReversal
|
16 |
+
|
17 |
+
class SwinModel(nn.Module):
    """Shared-backbone embedding model for contactless and contact-based images.

    Despite the attribute names, the backbone created here is timm's
    ``vit_large_patch16_224_in21k``.  ``swin_cb`` is the very same module
    object as ``swin_cl``, and both modalities are projected with
    ``linear_cl``.
    """

    def __init__(self):
        super().__init__()
        self.swin_cl = timm.create_model('vit_large_patch16_224_in21k', pretrained=True, num_classes=0)
        # Alias, not a copy: both modalities share one backbone.
        self.swin_cb = self.swin_cl

        self.linear_cl = nn.Sequential(nn.Linear(1024, 1024),
                                       nn.ReLU(),
                                       nn.Linear(1024, 1024))
        # Not used by get_embeddings()/forward(); kept so the module's
        # state_dict keys stay the same.
        self.linear_cb = nn.Linear(1024, 1024)

    def _set_encoder_trainable(self, trainable):
        """Set ``requires_grad`` on every backbone parameter."""
        # Both attributes are visited so this keeps working if ``swin_cb`` is
        # ever given its own backbone instead of aliasing ``swin_cl``.
        for backbone in (self.swin_cl, self.swin_cb):
            for param in backbone.parameters():
                param.requires_grad = trainable

    def freeze_encoder(self):
        """Stop gradient updates to the backbone parameters."""
        self._set_encoder_trainable(False)

    def unfreeze_encoder(self):
        """Re-enable gradient updates to the backbone parameters."""
        self._set_encoder_trainable(True)

    def get_embeddings(self, image, ftype):
        """Embed one batch of images of a single modality.

        Args:
            image: input batch for the backbone.
            ftype: ``"contactless"`` selects ``swin_cl``; any other value
                selects ``swin_cb``.

        Returns:
            ``(feat, tokens)``: ``tokens`` is the raw backbone output and
            ``feat`` is ``linear_cl`` applied to its mean over dimension 1
            (presumably the token axis -- confirm against the timm version
            in use).
        """
        swin = self.swin_cl if ftype == "contactless" else self.swin_cb

        tokens = swin(image)
        # The projection head is shared by both modalities, matching forward().
        feat = self.linear_cl(tokens.mean(dim=1))
        return feat, tokens

    def forward(self, x_cl, x_cb):
        """Embed a contactless batch and a contact-based batch.

        Returns:
            ``(x_cl, x_cb, x_cl_tokens, x_cb_tokens)``: the projected mean
            embeddings followed by the unprojected backbone outputs.
        """
        x_cl_tokens = self.swin_cl(x_cl)
        x_cb_tokens = self.swin_cb(x_cb)

        x_cl = self.linear_cl(x_cl_tokens.mean(dim=1))
        x_cb = self.linear_cl(x_cb_tokens.mean(dim=1))

        return x_cl, x_cb, x_cl_tokens, x_cb_tokens
|
64 |
+
|
65 |
+
class SwinModel_domain_agnostic(nn.Module):
    """Shared-backbone embedding model with a gradient-reversed domain classifier.

    Despite the attribute names, the backbone created here is timm's
    ``vit_large_patch16_224_in21k``.  ``swin_cb`` is the very same module
    object as ``swin_cl``, and both modalities are projected with
    ``linear_cl``.  ``classify`` (GradientReversal followed by a small MLP
    with 8 outputs) is applied to the unprojected mean backbone output.
    """

    def __init__(self):
        super().__init__()
        self.swin_cl = timm.create_model('vit_large_patch16_224_in21k', pretrained=True, num_classes=0)
        # Alias, not a copy: both modalities share one backbone.
        self.swin_cb = self.swin_cl

        self.linear_cl = nn.Sequential(nn.Linear(1024, 1024),
                                       nn.ReLU(),
                                       nn.Linear(1024, 1024))
        # Not used by get_embeddings()/forward(); kept so the module's
        # state_dict keys stay the same.
        self.linear_cb = nn.Linear(1024, 1024)
        self.classify = nn.Sequential(GradientReversal(alpha=0.6),  # original 0.8
                                      nn.Linear(1024, 512),
                                      nn.ReLU(),
                                      nn.Linear(512, 8))

    def _set_encoder_trainable(self, trainable):
        """Set ``requires_grad`` on every backbone parameter."""
        # Both attributes are visited so this keeps working if ``swin_cb`` is
        # ever given its own backbone instead of aliasing ``swin_cl``.
        for backbone in (self.swin_cl, self.swin_cb):
            for param in backbone.parameters():
                param.requires_grad = trainable

    def freeze_encoder(self):
        """Stop gradient updates to the backbone parameters."""
        self._set_encoder_trainable(False)

    def unfreeze_encoder(self):
        """Re-enable gradient updates to the backbone parameters."""
        self._set_encoder_trainable(True)

    def get_embeddings(self, image, ftype):
        """Embed one batch of images of a single modality.

        Args:
            image: input batch for the backbone.
            ftype: ``"contactless"`` selects ``swin_cl``; any other value
                selects ``swin_cb``.

        Returns:
            ``(feat, tokens)``: ``tokens`` is the raw backbone output and
            ``feat`` is ``linear_cl`` applied to its mean over dimension 1
            (presumably the token axis -- confirm against the timm version
            in use).
        """
        swin = self.swin_cl if ftype == "contactless" else self.swin_cb

        tokens = swin(image)
        # The projection head is shared by both modalities, matching forward().
        feat = self.linear_cl(tokens.mean(dim=1))
        return feat, tokens

    def forward(self, x_cl, x_cb):
        """Embed both modalities and predict a domain class for each.

        Returns:
            ``(x_cl, x_cb, x_cl_tokens, x_cb_tokens, domain_class_cl,
            domain_class_cb)``: projected mean embeddings, unprojected
            backbone outputs, and the ``classify`` outputs computed from the
            unprojected means.
        """
        x_cl_tokens = self.swin_cl(x_cl)
        x_cb_tokens = self.swin_cb(x_cb)

        x_cl_mean = x_cl_tokens.mean(dim=1)
        x_cb_mean = x_cb_tokens.mean(dim=1)

        x_cl = self.linear_cl(x_cl_mean)
        x_cb = self.linear_cl(x_cb_mean)

        domain_class_cl = self.classify(x_cl_mean)
        domain_class_cb = self.classify(x_cb_mean)

        return x_cl, x_cb, x_cl_tokens, x_cb_tokens, domain_class_cl, domain_class_cb
|
119 |
+
|
120 |
+
class SwinModel_Fusion(nn.Module):
    """Backbone plus a transformer that fuses the token sequences of two images.

    Despite the attribute name, ``swin_cl`` is timm's
    ``vit_large_patch16_224_in21k``.  ``combine_features`` concatenates two
    token sequences around a learned separator token, runs them through
    ``fusion`` and scores the pair from the two fused halves.
    """

    def __init__(self):
        super().__init__()
        self.feature_dim = 1024
        self.swin_cl = timm.create_model('vit_large_patch16_224_in21k', pretrained=True, num_classes=0)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=self.feature_dim, nhead=4, dropout=0.5, batch_first=True, norm_first=True, activation="gelu")
        self.fusion = nn.TransformerEncoder(self.encoder_layer, num_layers=2)
        self.sep_token = nn.Parameter(torch.randn(1, 1, self.feature_dim))
        self.output_logit_mlp = nn.Sequential(nn.Linear(1024, 512),
                                              nn.ReLU(),
                                              nn.Dropout(),
                                              nn.Linear(512, 1))
        self.linear_cl = nn.Sequential(nn.Linear(1024, 1024),
                                       nn.ReLU(),
                                       nn.Linear(1024, 1024))

    @staticmethod
    def _strip_prefix(state_dict, prefix):
        """Return the entries whose key starts with ``prefix``, with it removed."""
        return {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    def load_pretrained_models(self, swin_cl_path, fusion_ckpt_path):
        """Load the backbone and the fusion components from two checkpoints.

        Args:
            swin_cl_path: checkpoint whose ``swin_cl.*`` entries are loaded
                into the backbone.
            fusion_ckpt_path: checkpoint providing ``encoder_layer.*``,
                ``fusion.*``, ``output_logit_mlp.*`` and ``sep_token``.

        Checkpoints are read onto the CPU and copied into the existing
        parameters, so the model keeps its current device and ``sep_token``
        stays the same Parameter object (an optimizer that already holds it
        keeps updating it).
        """
        swin_cl_state_dict = torch.load(swin_cl_path, map_location="cpu")
        self.swin_cl.load_state_dict(self._strip_prefix(swin_cl_state_dict, "swin_cl."))

        fusion_params = torch.load(fusion_ckpt_path, map_location="cpu")
        for name in ("encoder_layer", "fusion", "output_logit_mlp"):
            getattr(self, name).load_state_dict(self._strip_prefix(fusion_params, name + "."))

        with torch.no_grad():
            self.sep_token.copy_(fusion_params["sep_token"])

    def l2_norm(self, input):
        """Scale every row of a 2-D tensor to (approximately) unit length."""
        buffer = torch.pow(input, 2)
        normp = torch.sum(buffer, 1).add_(1e-12)
        norm = torch.sqrt(normp)
        _output = torch.div(input, norm.view(-1, 1).expand_as(input))
        return _output

    def combine_features(self, fingerprint_1_tokens, fingerprint_2_tokens):
        """Fuse two token sequences and score each pair.

        Args:
            fingerprint_1_tokens: ``(B, N1, D)`` tokens of the first image.
            fingerprint_2_tokens: ``(B, N2, D)`` tokens of the second image.

        Returns:
            ``(similarities, distances)``, both of shape ``(B,)``: cosine
            similarity and Euclidean distance between the mean fused tokens
            of the first and of the second sequence.
        """
        batch_size, n_tokens_1 = fingerprint_1_tokens.shape[:2]
        sep_token = self.sep_token.expand(batch_size, -1, -1)
        n_sep = sep_token.shape[1]
        combine_features = torch.cat((fingerprint_1_tokens, sep_token, fingerprint_2_tokens), dim=1)
        fused_match_representation = self.fusion(combine_features)
        # Split at the separator using the actual sequence length rather than
        # a fixed 197-token offset.
        fingerprint_1 = fused_match_representation[:, :n_tokens_1, :].mean(dim=1)
        fingerprint_2 = fused_match_representation[:, n_tokens_1 + n_sep:, :].mean(dim=1)

        fingerprint_1_norm = self.l2_norm(fingerprint_1)
        fingerprint_2_norm = self.l2_norm(fingerprint_2)

        similarities = torch.sum(fingerprint_1_norm * fingerprint_2_norm, dim=1)

        differences = fingerprint_1 - fingerprint_2
        distances = torch.sqrt(torch.sum(differences ** 2, dim=1))
        return similarities, distances

    def get_tokens(self, image, ftype):
        """Return the backbone output for ``image``; ``ftype`` is accepted but unused."""
        return self.swin_cl(image)

    def freeze_backbone(self):
        """Stop gradient updates to the backbone parameters."""
        for param in self.swin_cl.parameters():
            param.requires_grad = False

    def forward(self, x_cl, x_cb):
        """Run both batches through the shared backbone and return their outputs."""
        x_cl_tokens = self.swin_cl(x_cl)
        x_cb_tokens = self.swin_cl(x_cb)
        return x_cl_tokens, x_cb_tokens
|
rb_evaluation_phase1.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from datasets.rb_loader import RB_loader
|
3 |
+
from utils import Prev_RetMetric, l2_norm, compute_recall_at_k
|
4 |
+
import numpy as np
|
5 |
+
from tqdm import tqdm
|
6 |
+
from model import SwinModel_domain_agnostic as Model
|
7 |
+
from sklearn.metrics import roc_curve, auc
|
8 |
+
import json
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
def far_bracket(fpr, far):
    """Return the ROC indices that bracket a target false-accept rate.

    Returns ``(lower, upper)``: the last index whose FPR is below ``far`` and
    the first index whose FPR is at least ``far``.  Raises ``ValueError`` when
    either side is empty.
    """
    lower = int(np.flatnonzero(fpr < far).max())
    upper = int(np.flatnonzero(fpr >= far).min())
    return lower, upper


if __name__ == '__main__':
    device = torch.device('cuda')
    data = RB_loader(split='test')
    dataloader = torch.utils.data.DataLoader(data, batch_size=16, num_workers=1, pin_memory=True)
    model = Model().to(device)
    checkpoint = torch.load("ridgeformer_checkpoints/phase1_scratch.pt", map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint, strict=False)

    model.eval()
    cl_feats, cb_feats, cl_labels, cb_labels = [], [], [], []
    cl_fnames, cb_fnames = [], []
    cl_feats_unnormed, cb_feats_unnormed = [], []
    print("Computing Test Recall")
    with torch.no_grad():
        for (x_cl, x_cb, target, cl_fname, cb_fname) in tqdm(dataloader):
            x_cl, x_cb, target = x_cl.to(device), x_cb.to(device), target.to(device)
            x_cl, _ = model.get_embeddings(x_cl, ftype="contactless")
            x_cb, _ = model.get_embeddings(x_cb, ftype="contactbased")
            # Raw embeddings feed the distance term, normalised ones the
            # cosine term of the match score.
            cl_feats_unnormed.append(x_cl.cpu().detach().numpy())
            cb_feats_unnormed.append(x_cb.cpu().detach().numpy())
            cl_feats.append(l2_norm(x_cl).cpu().detach().numpy())
            cb_feats.append(l2_norm(x_cb).cpu().detach().numpy())
            target = target.cpu().detach().numpy()
            cl_labels.append(target)
            cb_labels.append(target)
            cl_fnames.extend(cl_fname)
            cb_fnames.extend(cb_fname)

    cl_feats = torch.from_numpy(np.concatenate(cl_feats))
    cb_feats = torch.from_numpy(np.concatenate(cb_feats))
    cl_labels = torch.from_numpy(np.concatenate(cl_labels))
    cb_labels = torch.from_numpy(np.concatenate(cb_labels))
    cl_feats_unnormed = torch.from_numpy(np.concatenate(cl_feats_unnormed))
    cb_feats_unnormed = torch.from_numpy(np.concatenate(cb_feats_unnormed))

    # Collapse the contact-based gallery to one mean feature per label.
    unique_labels, indices = torch.unique(cb_labels, return_inverse=True)
    label_slots = range(len(unique_labels))
    cb_feats = torch.stack([cb_feats[indices == i].mean(dim=0) for i in label_slots])
    cb_feats_unnormed = torch.stack([cb_feats_unnormed[indices == i].mean(dim=0) for i in label_slots])
    cb_labels = unique_labels

    # CL2CB <---------------------------------------->
    cl_feats = cl_feats.numpy()
    cb_feats = cb_feats.numpy()
    cb_feats_unnormed = cb_feats_unnormed.numpy()
    cl_feats_unnormed = cl_feats_unnormed.numpy()

    # Score = cosine similarity - 0.1 * Euclidean distance of the raw embeddings.
    squared_diff = np.sum(np.square(cl_feats_unnormed[:, np.newaxis] - cb_feats_unnormed), axis=2)
    distance = -1 * np.sqrt(squared_diff)
    similarities = np.dot(cl_feats, np.transpose(cb_feats))
    scores_mat = similarities + 0.1 * distance
    scores = scores_mat.flatten().tolist()

    ids = torch.eq(cl_labels.view(-1, 1) - cb_labels.view(1, -1), 0.0).flatten().tolist()
    ids_mod = [1 if same else 0 for same in ids]
    fpr, tpr, thresh = roc_curve(ids_mod, scores, drop_intermediate=True)

    lower_fpr_idx, upper_fpr_idx = far_bracket(fpr, 0.01)
    tar_far_102 = tpr[upper_fpr_idx]
    print(tpr[lower_fpr_idx], lower_fpr_idx, fpr[lower_fpr_idx], thresh[lower_fpr_idx])
    print(tpr[upper_fpr_idx], upper_fpr_idx, fpr[upper_fpr_idx], thresh[upper_fpr_idx])

    lower_fpr_idx, upper_fpr_idx = far_bracket(fpr, 0.001)
    tar_far_103 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2
    print(tpr[lower_fpr_idx], lower_fpr_idx, fpr[lower_fpr_idx])
    print(tpr[upper_fpr_idx], upper_fpr_idx, fpr[upper_fpr_idx])

    lower_fpr_idx, upper_fpr_idx = far_bracket(fpr, 0.0001)
    tar_far_104 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2
    print(tpr[lower_fpr_idx], lower_fpr_idx, fpr[lower_fpr_idx])
    print(tpr[upper_fpr_idx], upper_fpr_idx, fpr[upper_fpr_idx])

    fnr = 1 - tpr
    EER = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    roc_auc = auc(fpr, tpr)
    print(f"ROCAUC for CB2CL: {roc_auc * 100} %")
    print(f"EER for CB2CL: {EER * 100} %")
    # Summary values; not read again in this script.
    eer_cb2cl = EER * 100
    cbcltf102 = tar_far_102 * 100
    cbcltf103 = tar_far_103 * 100
    cbcltf104 = tar_far_104 * 100
    print(f"TAR@FAR=10^-2 for CB2CL: {tar_far_102 * 100} %")
    print(f"TAR@FAR=10^-3 for CB2CL: {tar_far_103 * 100} %")
    print(f"TAR@FAR=10^-4 for CB2CL: {tar_far_104 * 100} %")
    for k in (1, 10, 50, 100):
        print(f"R@{k} for CB2CL: {compute_recall_at_k(torch.from_numpy(scores_mat), cl_labels, cb_labels, k) * 100} %")

    ################################################################################

    # CL2CL
    scores = torch.from_numpy(np.dot(cl_feats, np.transpose(cl_feats)))
    # Strict upper triangle: every unordered pair once, no self-pairs.
    row, col = torch.triu_indices(row=scores.size(0), col=scores.size(1), offset=1)
    scores = scores[row, col].numpy().flatten().tolist()
    # The labels are only turned into a Python list, so they stay on the CPU.
    labels = torch.eq(cl_labels.view(-1, 1) - cl_labels.view(1, -1), 0.0).float()
    labels = labels[row, col].tolist()
    fpr, tpr, _ = roc_curve(labels, scores)

    lower_fpr_idx, upper_fpr_idx = far_bracket(fpr, 0.01)
    tar_far_102 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2
    lower_fpr_idx, upper_fpr_idx = far_bracket(fpr, 0.001)
    tar_far_103 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2
    lower_fpr_idx, upper_fpr_idx = far_bracket(fpr, 0.0001)
    tar_far_104 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2
    clcltf102 = tar_far_102 * 100
    clcltf103 = tar_far_103 * 100
    clcltf104 = tar_far_104 * 100

    fnr = 1 - tpr
    EER = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    roc_auc = auc(fpr, tpr)
    print(f"ROCAUC for CL2CL: {roc_auc * 100} %")
    print(f"EER for CL2CL: {EER * 100} %")
    eer_cl2cl = EER * 100
    print(f"TAR@FAR=10^-2 for CL2CL: {tar_far_102 * 100} %")
    print(f"TAR@FAR=10^-3 for CL2CL: {tar_far_103 * 100} %")
    print(f"TAR@FAR=10^-4 for CL2CL: {tar_far_104 * 100} %")
    cl_labels = cl_labels.cpu().detach().numpy()
    recall_score = Prev_RetMetric([cl_feats, cl_feats], [cl_labels, cl_labels], cl2cl=True)
    cl2clk1 = recall_score.recall_k(k=1) * 100
    for k in (1, 10, 50, 100):
        print(f"R@{k} for CL2CL: {recall_score.recall_k(k=k) * 100} %")
|
rb_evaluation_phase2.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from datasets.rb_loader_cl import RB_loader_cl
|
3 |
+
from datasets.rb_loader_cb import RB_loader_cb
|
4 |
+
from utils import Prev_RetMetric, l2_norm, compute_recall_at_k
|
5 |
+
import numpy as np
|
6 |
+
from tqdm import tqdm
|
7 |
+
from model import SwinModel_Fusion as Model
|
8 |
+
from sklearn.metrics import roc_curve, auc
|
9 |
+
import json
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
def get_fused_cross_score_matrix(model, cl_tokens, cb_tokens, shard_size=20):
    """Score every (contactless, contact-based) pair with the fusion model.

    Args:
        model: object exposing ``combine_features(tokens_a, tokens_b)`` that
            returns ``(similarities, distances)`` with one value per pair.
        cl_tokens: list of token batches, concatenated along dimension 0.
        cb_tokens: list of token batches, concatenated along dimension 0.
        shard_size: number of samples from each side scored per call, i.e. at
            most ``shard_size ** 2`` pairs go through the model at once.

    Returns:
        CPU tensor of shape ``(num_cl, num_cb)`` holding
        ``similarity - 0.1 * distance`` for every pair.
    """
    cl_tokens = torch.cat(cl_tokens)
    cb_tokens = torch.cat(cb_tokens)

    batch_size_cl = cl_tokens.shape[0]
    batch_size_cb = cb_tokens.shape[0]
    similarity_matrix = torch.zeros((batch_size_cl, batch_size_cb))

    # The scores are detached and moved to the CPU anyway, so skip autograd
    # bookkeeping for the fusion passes.
    with torch.no_grad():
        for i_start in tqdm(range(0, batch_size_cl, shard_size)):
            shard_i = cl_tokens[i_start:i_start + shard_size]
            n_i = shard_i.shape[0]
            for j_start in range(0, batch_size_cb, shard_size):
                shard_j = cb_tokens[j_start:j_start + shard_size]
                n_j = shard_j.shape[0]

                # Cartesian product of the two shards: (n_i, n_j, tokens, dim).
                pairwise_i = shard_i.unsqueeze(1).expand(-1, n_j, -1, -1)
                pairwise_j = shard_j.unsqueeze(0).expand(n_i, -1, -1, -1)

                # Token count is taken from the data rather than fixed at 197.
                similarity_scores, distances = model.combine_features(
                    pairwise_i.reshape(-1, shard_i.shape[-2], shard_i.shape[-1]),
                    pairwise_j.reshape(-1, shard_j.shape[-2], shard_j.shape[-1]),
                )
                scores = similarity_scores - 0.1 * distances
                similarity_matrix[i_start:i_start + n_i, j_start:j_start + n_j] = (
                    scores.reshape(n_i, n_j).cpu().detach()
                )
    return similarity_matrix
|
40 |
+
|
41 |
+
def far_bracket(fpr, far):
    """Return the ROC indices that bracket a target false-accept rate.

    Returns ``(lower, upper)``: the last index whose FPR is below ``far`` and
    the first index whose FPR is at least ``far``.  Raises ``ValueError`` when
    either side is empty.
    """
    lower = int(np.flatnonzero(fpr < far).max())
    upper = int(np.flatnonzero(fpr >= far).min())
    return lower, upper


device = torch.device('cuda')
data_cl = RB_loader_cl(split="test")
data_cb = RB_loader_cb(split="test")
dataloader_cb = torch.utils.data.DataLoader(data_cb, batch_size=16, num_workers=1, pin_memory=True)
dataloader_cl = torch.utils.data.DataLoader(data_cl, batch_size=16, num_workers=1, pin_memory=True)
model = Model().to(device)
checkpoint = torch.load("ridgeformer_checkpoints/phase2_scratch.pt", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint, strict=False)

model.eval()
cl_feats, cb_feats, cl_labels, cb_labels = [], [], [], []
print("Computing Test Recall")

with torch.no_grad():
    for (x_cb, target) in tqdm(dataloader_cb):
        x_cb, label = x_cb.to(device), target.to(device)
        cb_feats.append(model.get_tokens(x_cb, 'contactbased'))
        cb_labels.append(label.cpu().detach().numpy())

with torch.no_grad():
    for (x_cl, target) in tqdm(dataloader_cl):
        x_cl, label = x_cl.to(device), target.to(device)
        cl_feats.append(model.get_tokens(x_cl, 'contactless'))
        cl_labels.append(label.cpu().detach().numpy())

cl_label = torch.from_numpy(np.concatenate(cl_labels))
cb_label = torch.from_numpy(np.concatenate(cb_labels))

# CB2CL <---------------------------------------->
scores_mat = get_fused_cross_score_matrix(model, cl_feats, cb_feats)
scores = scores_mat.cpu().detach().numpy().flatten().tolist()
is_match = torch.eq(cl_label.view(-1, 1) - cb_label.view(1, -1), 0.0).flatten().tolist()
labels = [1 if same else 0 for same in is_match]
fpr, tpr, thresh = roc_curve(labels, scores, drop_intermediate=True)

lower_fpr_idx, upper_fpr_idx = far_bracket(fpr, 0.01)
tar_far_102 = tpr[upper_fpr_idx]

lower_fpr_idx, upper_fpr_idx = far_bracket(fpr, 0.001)
tar_far_103 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2

lower_fpr_idx, upper_fpr_idx = far_bracket(fpr, 0.0001)
tar_far_104 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2

fnr = 1 - tpr

EER = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
roc_auc = auc(fpr, tpr)
print(f"ROCAUC for CB2CL: {roc_auc * 100} %")
print(f"EER for CB2CL: {EER * 100} %")
# Summary values; not read again in this script.
eer_cb2cl = EER * 100
cbcltf102 = tar_far_102 * 100
cbcltf103 = tar_far_103 * 100
cbcltf104 = tar_far_104 * 100
print(f"TAR@FAR=10^-2 for CB2CL: {tar_far_102 * 100} %")
print(f"TAR@FAR=10^-3 for CB2CL: {tar_far_103 * 100} %")
print(f"TAR@FAR=10^-4 for CB2CL: {tar_far_104 * 100} %")

for k in (1, 10, 50, 100):
    print(f"R@{k} for CB2CL: {compute_recall_at_k(scores_mat, cl_label, cb_label, k) * 100} %")

# CL2CL -------------------------
scores_mat = get_fused_cross_score_matrix(model, cl_feats, cl_feats)
# Strict upper triangle: every unordered pair once, no self-pairs.
row, col = torch.triu_indices(row=scores_mat.size(0), col=scores_mat.size(1), offset=1)
scores = scores_mat[row, col].cpu().detach().numpy().flatten().tolist()
# The labels are only turned into a Python list, so they stay on the CPU.
is_match = torch.eq(cl_label.view(-1, 1) - cl_label.view(1, -1), 0.0)[row, col].tolist()
labels = [1 if same else 0 for same in is_match]
fpr, tpr, thresh = roc_curve(labels, scores, drop_intermediate=True)

lower_fpr_idx, upper_fpr_idx = far_bracket(fpr, 0.01)
tar_far_102 = tpr[upper_fpr_idx]

lower_fpr_idx, upper_fpr_idx = far_bracket(fpr, 0.001)
tar_far_103 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2

lower_fpr_idx, upper_fpr_idx = far_bracket(fpr, 0.0001)
tar_far_104 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2

fnr = 1 - tpr

EER = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
roc_auc = auc(fpr, tpr)
print(f"ROCAUC for CL2CL: {roc_auc * 100} %")
print(f"EER for CL2CL: {EER * 100} %")
# Separate names so the CB2CL summary values above are not overwritten.
eer_cl2cl = EER * 100
clcltf102 = tar_far_102 * 100
clcltf103 = tar_far_103 * 100
clcltf104 = tar_far_104 * 100
print(f"TAR@FAR=10^-2 for CL2CL: {tar_far_102 * 100} %")
print(f"TAR@FAR=10^-3 for CL2CL: {tar_far_103 * 100} %")
print(f"TAR@FAR=10^-4 for CL2CL: {tar_far_104 * 100} %")

# NOTE(review): scores_mat still contains the diagonal (each sample against
# itself); confirm compute_recall_at_k excludes self-matches for CL2CL.
for k in (1, 10, 50, 100):
    print(f"R@{k} for CL2CL: {compute_recall_at_k(scores_mat, cl_label, cl_label, k) * 100} %")
|
requirements.txt
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file may be used to create an environment using:
|
2 |
+
# $ conda create --name <env> --file <this file>
|
3 |
+
# platform: linux-64
|
4 |
+
_libgcc_mutex=0.1=main
|
5 |
+
_openmp_mutex=5.1=1_gnu
|
6 |
+
absl-py=2.1.0=pypi_0
|
7 |
+
addict=2.4.0=pypi_0
|
8 |
+
aliyun-python-sdk-core=2.15.0=pypi_0
|
9 |
+
aliyun-python-sdk-kms=2.16.2=pypi_0
|
10 |
+
attrs=23.2.0=pypi_0
|
11 |
+
blas=1.0=mkl
|
12 |
+
bzip2=1.0.8=h5eee18b_6
|
13 |
+
ca-certificates=2024.7.2=h06a4308_0
|
14 |
+
cachetools=5.4.0=pypi_0
|
15 |
+
certifi=2024.2.2=pypi_0
|
16 |
+
cffi=1.16.0=pypi_0
|
17 |
+
charset-normalizer=2.1.1=pypi_0
|
18 |
+
click=8.1.7=pypi_0
|
19 |
+
colorama=0.4.6=pypi_0
|
20 |
+
coloredlogs=15.0.1=pypi_0
|
21 |
+
contourpy=1.1.1=pypi_0
|
22 |
+
crcmod=1.7=pypi_0
|
23 |
+
cryptography=42.0.5=pypi_0
|
24 |
+
cuda-cudart=11.8.89=0
|
25 |
+
cuda-cudart_linux-64=12.4.127=hd681fbe_0
|
26 |
+
cuda-cupti=11.8.87=0
|
27 |
+
cuda-libraries=11.8.0=0
|
28 |
+
cuda-nvrtc=11.8.89=0
|
29 |
+
cuda-nvtx=11.8.86=0
|
30 |
+
cuda-opencl=12.4.127=h6a678d5_0
|
31 |
+
cuda-runtime=11.8.0=0
|
32 |
+
cuda-version=12.4=hbda6634_3
|
33 |
+
cycler=0.12.1=pypi_0
|
34 |
+
entrypoints=0.4=pypi_0
|
35 |
+
ffmpeg=4.3=hf484d3e_0
|
36 |
+
flatbuffers=24.3.25=pypi_0
|
37 |
+
fonttools=4.53.1=pypi_0
|
38 |
+
freetype=2.12.1=h4a9f257_0
|
39 |
+
fsspec=2024.3.1=pypi_0
|
40 |
+
gmp=6.2.1=h295c915_3
|
41 |
+
gnutls=3.6.15=he1e5248_0
|
42 |
+
google-auth=2.33.0=pypi_0
|
43 |
+
google-auth-oauthlib=1.0.0=pypi_0
|
44 |
+
grpcio=1.65.4=pypi_0
|
45 |
+
httpcore=1.0.5=pypi_0
|
46 |
+
httpx=0.27.0=pypi_0
|
47 |
+
huggingface-hub=0.22.1=pypi_0
|
48 |
+
humanfriendly=10.0=pypi_0
|
49 |
+
idna=3.6=pypi_0
|
50 |
+
imageio=2.34.2=pypi_0
|
51 |
+
importlib-metadata=7.1.0=pypi_0
|
52 |
+
importlib-resources=6.4.0=pypi_0
|
53 |
+
intel-openmp=2023.1.0=hdb19cb5_46306
|
54 |
+
jinja2=3.1.3=pypi_0
|
55 |
+
jmespath=0.10.0=pypi_0
|
56 |
+
joblib=1.4.2=pypi_0
|
57 |
+
jpeg=9e=h5eee18b_2
|
58 |
+
jsonschema=4.21.1=pypi_0
|
59 |
+
jsonschema-specifications=2023.12.1=pypi_0
|
60 |
+
kaleido=0.2.1=pypi_0
|
61 |
+
kiwisolver=1.4.5=pypi_0
|
62 |
+
lame=3.100=h7b6447c_0
|
63 |
+
lcms2=2.12=h3be6417_0
|
64 |
+
ld_impl_linux-64=2.38=h1181459_1
|
65 |
+
lerc=3.0=h295c915_0
|
66 |
+
libcublas=11.11.3.6=0
|
67 |
+
libcufft=10.9.0.58=0
|
68 |
+
libcufile=1.9.1.3=h99ab3db_1
|
69 |
+
libcurand=10.3.5.147=h99ab3db_1
|
70 |
+
libcusolver=11.4.1.48=0
|
71 |
+
libcusparse=11.7.5.86=0
|
72 |
+
libdeflate=1.17=h5eee18b_1
|
73 |
+
libffi=3.4.4=h6a678d5_1
|
74 |
+
libgcc-ng=11.2.0=h1234567_1
|
75 |
+
libgomp=11.2.0=h1234567_1
|
76 |
+
libiconv=1.16=h5eee18b_3
|
77 |
+
libidn2=2.3.4=h5eee18b_0
|
78 |
+
libjpeg-turbo=2.0.0=h9bf148f_0
|
79 |
+
libnpp=11.8.0.86=0
|
80 |
+
libnvfatbin=12.4.127=h7934f7d_2
|
81 |
+
libnvjitlink=12.4.99=0
|
82 |
+
libnvjpeg=11.9.0.86=0
|
83 |
+
libpng=1.6.39=h5eee18b_0
|
84 |
+
libstdcxx-ng=11.2.0=h1234567_1
|
85 |
+
libtasn1=4.19.0=h5eee18b_0
|
86 |
+
libtiff=4.5.1=h6a678d5_0
|
87 |
+
libunistring=0.9.10=h27cfd23_0
|
88 |
+
libwebp-base=1.3.2=h5eee18b_0
|
89 |
+
llvm-openmp=14.0.6=h9e868ea_0
|
90 |
+
llvmlite=0.41.1=pypi_0
|
91 |
+
lz4-c=1.9.4=h6a678d5_1
|
92 |
+
markdown=3.6=pypi_0
|
93 |
+
markdown-it-py=3.0.0=pypi_0
|
94 |
+
markupsafe=2.1.5=pypi_0
|
95 |
+
matplotlib=3.7.5=pypi_0
|
96 |
+
mdit-py-plugins=0.4.0=pypi_0
|
97 |
+
mkl=2023.1.0=h213fc3f_46344
|
98 |
+
mmcv=2.1.0=dev_0
|
99 |
+
mmdet=3.3.0=dev_0
|
100 |
+
mmengine=0.10.3=pypi_0
|
101 |
+
model-index=0.1.11=pypi_0
|
102 |
+
mpc=1.1.0=h10f8cd9_1
|
103 |
+
mpfr=4.0.2=hb69a4c5_1
|
104 |
+
mpmath=1.3.0=py38h06a4308_0
|
105 |
+
ncurses=6.4=h6a678d5_0
|
106 |
+
nettle=3.7.3=hbbd107a_1
|
107 |
+
networkx=3.1=py38h06a4308_0
|
108 |
+
numba=0.58.1=pypi_0
|
109 |
+
numpy=1.24.4=pypi_0
|
110 |
+
nvidia-cublas-cu11=11.11.3.6=pypi_0
|
111 |
+
nvidia-cuda-cupti-cu11=11.8.87=pypi_0
|
112 |
+
nvidia-cuda-nvrtc-cu11=11.8.89=pypi_0
|
113 |
+
nvidia-cuda-runtime-cu11=11.8.89=pypi_0
|
114 |
+
nvidia-cudnn-cu11=8.7.0.84=pypi_0
|
115 |
+
nvidia-cufft-cu11=10.9.0.58=pypi_0
|
116 |
+
nvidia-curand-cu11=10.3.0.86=pypi_0
|
117 |
+
nvidia-cusolver-cu11=11.4.1.48=pypi_0
|
118 |
+
nvidia-cusparse-cu11=11.7.5.86=pypi_0
|
119 |
+
nvidia-nccl-cu11=2.19.3=pypi_0
|
120 |
+
nvidia-nvtx-cu11=11.8.86=pypi_0
|
121 |
+
oauthlib=3.2.2=pypi_0
|
122 |
+
ocl-icd=2.3.2=h5eee18b_1
|
123 |
+
onnxruntime=1.18.1=pypi_0
|
124 |
+
opencv-python=4.10.0.84=pypi_0
|
125 |
+
opencv-python-headless=4.10.0.84=pypi_0
|
126 |
+
opendatalab=0.0.10=pypi_0
|
127 |
+
openh264=2.1.1=h4ff587b_0
|
128 |
+
openjpeg=2.4.0=h9ca470c_2
|
129 |
+
openmim=0.3.9=pypi_0
|
130 |
+
openssl=3.0.14=h5eee18b_0
|
131 |
+
openxlab=0.0.37=pypi_0
|
132 |
+
ordered-set=4.1.0=pypi_0
|
133 |
+
oss2=2.17.0=pypi_0
|
134 |
+
packaging=24.0=pypi_0
|
135 |
+
pandas=2.0.3=pypi_0
|
136 |
+
pillow=9.0.1=pypi_0
|
137 |
+
pip=23.3.1=pypi_0
|
138 |
+
pkgutil-resolve-name=1.3.10=pypi_0
|
139 |
+
platformdirs=4.2.0=pypi_0
|
140 |
+
plotly=5.23.0=pypi_0
|
141 |
+
pooch=1.8.2=pypi_0
|
142 |
+
protobuf=5.27.3=pypi_0
|
143 |
+
pyasn1=0.6.0=pypi_0
|
144 |
+
pyasn1-modules=0.4.0=pypi_0
|
145 |
+
pycocotools=2.0.7=pypi_0
|
146 |
+
pycparser=2.21=pypi_0
|
147 |
+
pygments=2.17.2=pypi_0
|
148 |
+
pymatting=1.1.12=pypi_0
|
149 |
+
pyparsing=3.1.2=pypi_0
|
150 |
+
python=3.8.19=h955ad1f_0
|
151 |
+
python-dateutil=2.9.0.post0=pypi_0
|
152 |
+
pytorch-cuda=11.8=h7e8668a_5
|
153 |
+
pytorch-metric-learning=2.5.0=pypi_0
|
154 |
+
pytorch-mutex=1.0=cuda
|
155 |
+
pytz=2023.4=pypi_0
|
156 |
+
pywavelets=1.4.1=pypi_0
|
157 |
+
pyyaml=6.0.1=py38h5eee18b_0
|
158 |
+
readline=8.2=h5eee18b_0
|
159 |
+
referencing=0.34.0=pypi_0
|
160 |
+
rembg=2.0.58=pypi_0
|
161 |
+
requests=2.28.2=pypi_0
|
162 |
+
requests-oauthlib=2.0.0=pypi_0
|
163 |
+
rich=13.4.2=pypi_0
|
164 |
+
rpds-py=0.18.0=pypi_0
|
165 |
+
rsa=4.9=pypi_0
|
166 |
+
safetensors=0.4.2=pypi_0
|
167 |
+
scikit-image=0.19.3=pypi_0
|
168 |
+
scikit-learn=1.3.2=pypi_0
|
169 |
+
scipy=1.10.1=pypi_0
|
170 |
+
setuptools=60.2.0=pypi_0
|
171 |
+
shapely=2.0.3=pypi_0
|
172 |
+
six=1.16.0=pypi_0
|
173 |
+
sqlite=3.45.3=h5eee18b_0
|
174 |
+
sympy=1.12=py38h06a4308_0
|
175 |
+
tabulate=0.9.0=pypi_0
|
176 |
+
tbb=2021.8.0=hdb19cb5_0
|
177 |
+
tenacity=9.0.0=pypi_0
|
178 |
+
tensorboard=2.14.0=pypi_0
|
179 |
+
tensorboard-data-server=0.7.2=pypi_0
|
180 |
+
termcolor=2.4.0=pypi_0
|
181 |
+
terminaltables=3.1.10=pypi_0
|
182 |
+
threadpoolctl=3.5.0=pypi_0
|
183 |
+
tifffile=2023.7.10=pypi_0
|
184 |
+
timm=0.5.0=dev_0
|
185 |
+
tk=8.6.14=h39e8969_0
|
186 |
+
tomli=2.0.1=pypi_0
|
187 |
+
torch=2.2.2+cu118=pypi_0
|
188 |
+
torchaudio=2.2.2+cu118=pypi_0
|
189 |
+
torchvision=0.17.2+cu118=pypi_0
|
190 |
+
tqdm=4.66.5=pypi_0
|
191 |
+
triton=2.2.0=pypi_0
|
192 |
+
typing-extensions=4.10.0=pypi_0
|
193 |
+
tzdata=2024.1=pypi_0
|
194 |
+
urllib3=1.26.18=pypi_0
|
195 |
+
werkzeug=3.0.3=pypi_0
|
196 |
+
wheel=0.41.2=pypi_0
|
197 |
+
xz=5.4.6=h5eee18b_1
|
198 |
+
yaml=0.2.5=h7b6447c_0
|
199 |
+
yapf=0.40.2=pypi_0
|
200 |
+
zipp=3.18.1=pypi_0
|
201 |
+
zlib=1.2.13=h5eee18b_1
|
202 |
+
zstd=1.5.5=hc292b87_2
|
train_combined.py
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import argparse
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torch.optim as optim
|
8 |
+
from torchvision import datasets, transforms
|
9 |
+
from torch.optim.lr_scheduler import StepLR, MultiStepLR
|
10 |
+
from datasets.hkpoly_test import hktest
|
11 |
+
from datasets.original_combined_train import Combined_original
|
12 |
+
from datasets.rb_loader import RB_loader
|
13 |
+
from loss import DualMSLoss_FineGrained_domain_agnostic_ft, DualMSLoss_FineGrained, DualMSLoss_FineGrained_domain_agnostic
|
14 |
+
import timm
|
15 |
+
from utils import Prev_RetMetric, RetMetric, compute_recall_at_k, l2_norm, compute_sharded_cosine_similarity, count_parameters
|
16 |
+
from pprint import pprint
|
17 |
+
import numpy as np
|
18 |
+
from tqdm import tqdm
|
19 |
+
from combined_sampler import BalancedSampler
|
20 |
+
from torch.utils.data.sampler import BatchSampler
|
21 |
+
from torch.nn.parallel import DataParallel
|
22 |
+
from model import SwinModel_domain_agnostic as Model
|
23 |
+
import matplotlib.pyplot as plt
|
24 |
+
from sklearn.metrics import roc_curve, auc
|
25 |
+
import json
|
26 |
+
from torch.utils.tensorboard import SummaryWriter
|
27 |
+
|
28 |
+
def train(args, model, device, train_loader, test_loader, optimizers, epoch, loss_func, pl_arg, stepping, log_writer):
    """Run one training epoch and return the mean batch loss.

    Args:
        args: Parsed CLI namespace; ``log_interval`` and ``dry_run`` are read.
        model: Network called as ``model(x_cl, x_cb)``. It must return the two
            embeddings, the two token tensors and the two domain predictions,
            in that order.
        device: Device every batch tensor is moved to.
        train_loader: Iterable yielding ``(x_cl, x_cb, target, category_cl,
            category_cb)`` batches.
        test_loader: Unused here; kept so existing callers keep working.
        optimizers: Optimizers zeroed before and stepped after each backward
            pass.
        epoch: Unused here; kept so existing callers keep working.
        loss_func: Callable that turns the model outputs, targets and
            categories into a scalar loss tensor.
        pl_arg: Unused here; kept so existing callers keep working.
        stepping: Step counter; returned unchanged.
        log_writer: Unused here; kept so existing callers keep working.

    Returns:
        ``(mean_loss, stepping)`` where ``mean_loss`` is a detached scalar
        tensor (NaN when the loader produced no batch).
    """
    model.train()
    batch_losses = []
    for batch_idx, (x_cl, x_cb, target, category_cl, category_cb) in enumerate(pbar := tqdm(train_loader)):
        x_cl = x_cl.to(device)
        x_cb = x_cb.to(device)
        target = target.to(device)
        category_cl = category_cl.to(device)
        category_cb = category_cb.to(device)

        for optimizer in optimizers:
            optimizer.zero_grad()
        x_cl, x_cb, x_cl_tokens, x_cb_tokens, domain_class_cl, domain_class_cb = model(x_cl, x_cb)
        loss = loss_func(x_cl, x_cb, x_cl_tokens, x_cb_tokens, target, device,
                         domain_class_cl, domain_class_cb, category_cl, category_cb)
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()

        pbar.set_description(f"Loss {loss}")
        # Store a detached copy so the list does not keep autograd history alive.
        batch_losses.append(loss.detach())

        # Record the loss before leaving: a dry run stops after the first
        # logged batch and must still have something to average.
        if batch_idx % args.log_interval == 0 and args.dry_run:
            break

    if not batch_losses:
        # Empty loader: avoid dividing by zero.
        return torch.tensor(float("nan")), stepping
    return sum(batch_losses) / len(batch_losses), stepping
|
46 |
+
|
47 |
+
def l2_norm(input):
    """Scale ``input`` to unit L2 norm along dimension 1.

    A small epsilon (1e-12) is added to the squared norm so all-zero rows do
    not divide by zero. Broadcasting with ``keepdim=True`` gives the same
    result as before for 2-D input and also supports tensors with more than
    two dimensions.

    Args:
        input: Tensor with at least two dimensions.

    Returns:
        Tensor of the same shape as ``input``.
    """
    squared_norm = torch.sum(torch.pow(input, 2), dim=1, keepdim=True) + 1e-12
    return torch.div(input, torch.sqrt(squared_norm))
|
55 |
+
|
56 |
+
def _tar_at_far(fpr, tpr, far):
    """Return TAR at ``far`` as the mean of the two ROC points bracketing it.

    ``fpr`` must be non-decreasing, as returned by ``roc_curve``.
    """
    first_at_or_above = int(np.searchsorted(fpr, far, side="left"))
    upper = min(first_at_or_above, len(fpr) - 1)
    lower = max(first_at_or_above - 1, 0)
    return (tpr[lower] + tpr[upper]) / 2


def hkpoly_test_fn(model,device,test_loader,epoch,plot_argument):
    """Evaluate contact-based vs. contactless (CB2CL) matching on the test set.

    Embeddings of every test pair are L2-normalised, the full score matrix is
    their dot product, and ROC-based metrics plus recall@k are printed. The
    score matrix and the ROC plot are written under
    ``combined_models_scores/``.

    Args:
        model: Network exposing ``get_embeddings(x, domain)``; left in eval
            mode on return.
        device: Device the test batches are moved to.
        test_loader: Iterable yielding ``(x_cl, x_cb, label)`` batches.
        epoch: Epoch number, used only in output file names.
        plot_argument: Sequence of at least four strings used in output file
            names.

    Returns:
        ``(recall@1, EER, TAR@FAR=1e-2, TAR@FAR=1e-3, TAR@FAR=1e-4)``, all in
        percent.
    """
    model.eval()
    cl_feats, cb_feats, cl_labels, cb_labels = [], [], [], []
    with torch.no_grad():
        for (x_cl, x_cb, label) in tqdm(test_loader):
            x_cl, x_cb, label = x_cl.to(device), x_cb.to(device), label.to(device)
            x_cl_feat, _ = model.get_embeddings(x_cl,'contactless')
            x_cb_feat, _ = model.get_embeddings(x_cb,'contactbased')
            label = label.cpu().detach().numpy()
            cl_feats.append(l2_norm(x_cl_feat).cpu().detach().numpy())
            cb_feats.append(l2_norm(x_cb_feat).cpu().detach().numpy())
            cl_labels.append(label)
            cb_labels.append(label)

    cl_feats = np.concatenate(cl_feats)
    cb_feats = np.concatenate(cb_feats)
    cl_label = torch.from_numpy(np.concatenate(cl_labels))
    cb_label = torch.from_numpy(np.concatenate(cb_labels))

    # CB2CL
    run_suffix = plot_argument[0]+"_"+plot_argument[1]+"_"+plot_argument[2]+"_"+plot_argument[3]
    os.makedirs("combined_models_scores", exist_ok=True)
    scores = np.dot(cl_feats,np.transpose(cb_feats))
    np.save("combined_models_scores/task1_cb2cl_score_matrix_"+str(epoch)+"_"+run_suffix+".npy", scores)

    # A pair is genuine when both samples carry the same label.
    labels = torch.eq(cl_label.view(-1,1) - cb_label.view(1,-1),0.0).flatten().numpy()
    fpr,tpr,_ = roc_curve(labels,scores.ravel())

    tar_far_102 = _tar_at_far(fpr, tpr, 0.01)
    tar_far_103 = _tar_at_far(fpr, tpr, 0.001)
    tar_far_104 = _tar_at_far(fpr, tpr, 0.0001)

    # EER is read off as the FPR at the ROC point where FNR and FPR are closest.
    fnr = 1 - tpr
    EER = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    roc_auc = auc(fpr, tpr)

    fig = plt.figure()
    plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], 'k--', label='No Skill')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve CB2CL task1')
    plt.legend(loc="lower right")
    plt.savefig("combined_models_scores/roc_curve_cb2cl_task1_"+"_"+run_suffix+str(epoch)+".png", dpi=300, bbox_inches='tight')
    # This runs every epoch; close the figure so they do not pile up in memory.
    plt.close(fig)

    print(f"ROCAUC for CB2CL: {roc_auc * 100} %")
    print(f"EER for CB2CL: {EER * 100} %")
    eer_cb2cl = EER * 100
    print(f"TAR@FAR=10^-2 for CB2CL: {tar_far_102 * 100} %")
    print(f"TAR@FAR=10^-3 for CB2CL: {tar_far_103 * 100} %")
    print(f"TAR@FAR=10^-4 for CB2CL: {tar_far_104 * 100} %")
    cbcltf102 = tar_far_102 * 100
    cbcltf103 = tar_far_103 * 100
    cbcltf104 = tar_far_104 * 100

    cl_label = cl_label.cpu().detach().numpy()
    cb_label = cb_label.cpu().detach().numpy()
    recall_score = Prev_RetMetric([cb_feats,cl_feats],[cb_label,cl_label],cl2cl = False)
    cl2cbk1 = recall_score.recall_k(k=1) * 100
    print(f"R@1 for CB2CL: {cl2cbk1} %")
    print(f"R@10 for CB2CL: {recall_score.recall_k(k=10) * 100} %")
    print(f"R@50 for CB2CL: {recall_score.recall_k(k=50) * 100} %")
    print(f"R@100 for CB2CL: {recall_score.recall_k(k=100) * 100} %")

    return cl2cbk1,eer_cb2cl,cbcltf102,cbcltf103,cbcltf104
|
130 |
+
|
131 |
+
def main():
    """Train the domain-agnostic Swin model and evaluate CB2CL matching every epoch.

    Parses the command line, builds the balanced training loader and the test
    loader, loads the phase-1 checkpoint, then for each epoch trains, steps
    the LR scheduler once warm-up is over, evaluates with ``hkpoly_test_fn``,
    logs to TensorBoard and saves a checkpoint. A best-epoch summary is
    printed at the end.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description='Train the domain-agnostic Swin model on the combined datasets')
    # The original default referenced an undefined name and used type=list,
    # which would split a path into characters; take one or more paths instead.
    parser.add_argument('--manifest-list', type=str, nargs='+', default=None,
                        help='list of manifest files from different datasets to train on')
    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: %(default)s)')
    parser.add_argument('--test-batch-size', type=int, default=16, metavar='N',
                        help='input batch size for testing (default: %(default)s)')
    parser.add_argument('--epochs', type=int, default=50, metavar='N',
                        help='number of epochs to train (default: %(default)s)')
    parser.add_argument('--lr_linear', type=float, default=1.0, metavar='LR',
                        help='learning rate for the linear and classifier heads (default: %(default)s)')
    parser.add_argument('--lr_swin', type=float, default=1.0, metavar='LR',
                        help='learning rate for the Swin encoder (default: %(default)s)')
    # NOTE(review): --gamma is parsed but the scheduler below hard-codes gamma=0.7.
    parser.add_argument('--gamma', type=float, default=0.9, metavar='M',
                        help='Learning rate step gamma (default: %(default)s)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: %(default)s)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--warmup', type=int, default=2, metavar='N',
                        help='warm up rate for feature extractor')
    parser.add_argument('--model-name', type=str, default="ridgeformer",
                        help='Name of the model for checkpointing')
    args = parser.parse_args()
    if not args.manifest_list:
        parser.error('--manifest-list requires at least one manifest file')

    checkpoint_save_path = "ridgeformer_checkpoints/"
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    # makedirs also creates the parent "experiment_logs" directory if missing.
    os.makedirs("experiment_logs/"+args.model_name, exist_ok=True)

    log_writer = SummaryWriter("experiment_logs/"+args.model_name+"/",comment = str(args.batch_size)+str(args.lr_linear)+str(args.lr_swin))

    torch.manual_seed(args.seed)
    # BalancedSampler draws with np.random, so seed numpy as well.
    np.random.seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")

    print("loading Normal RGB images -----------------------------")
    train_dataset = Combined_original(args.manifest_list,split="train")
    val_dataset = hktest(split="test")

    balanced_sampler = BalancedSampler(train_dataset, batch_size = args.batch_size, images_per_class = 2)
    batch_sampler = BatchSampler(balanced_sampler, batch_size = args.batch_size, drop_last = True)

    train_kwargs = {'batch_sampler': batch_sampler}
    test_kwargs = {'batch_size': args.test_batch_size}

    if use_cuda:
        cuda_kwargs = {
            'num_workers': 1,
            'pin_memory': True
        }
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)

    model = Model().to(device)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    ckpt = torch.load("ridgeformer_checkpoints/phase1_scratch.pt", map_location=torch.device('cpu'))
    model.load_state_dict(ckpt,strict=False)
    print("Number of Trainable Parameters: - ", count_parameters(model))

    loss_func = DualMSLoss_FineGrained_domain_agnostic()

    optimizer_swin = optim.AdamW(
        [
            {"params": model.swin_cl.parameters(), "lr":args.lr_swin},
            {"params": model.classify.parameters(), "lr":args.lr_linear},
            {"params": model.linear_cl.parameters(), "lr":args.lr_linear},
            {"params": model.linear_cb.parameters(), "lr":args.lr_linear},
        ],
        weight_decay=0.000001,
        lr=args.lr_swin)

    scheduler_swin = MultiStepLR(optimizer_swin, milestones = [100], gamma=0.7)

    # hkpoly_test_fn in this script reports CB2CL metrics only, so only those
    # are tracked (unpacking ten values from its five-tuple raised ValueError).
    cb2cl_lst = list()
    eer_cb2cl_lst = list()
    cbcltf102_lst,cbcltf103_lst,cbcltf104_lst = list(),list(),list()
    run_tag = [args.model_name,str(args.batch_size),str(args.lr_linear),str(args.lr_swin)]
    stepping = 1
    for epoch in range(1, args.epochs + 1):
        print(f"running epoch------ {epoch}")
        if (epoch > args.warmup):
            print("Training with Swin")
            model.unfreeze_encoder()
        else:
            print("Training only linear")
            model.freeze_encoder()

        avg_step_loss,stepping = train(args, model, device, train_loader, test_loader, [optimizer_swin], epoch, loss_func, run_tag,stepping,log_writer)

        print(f"Learning Rate for {epoch} for swin = {scheduler_swin.get_last_lr()}")

        log_writer.add_scalar('Swin_LR/epoch',scheduler_swin.get_last_lr()[0],epoch)

        if (epoch > args.warmup):
            scheduler_swin.step()

        cl2cbk1,eer_cb2cl,cbcltf102,cbcltf103,cbcltf104 = hkpoly_test_fn(model, device, test_loader, epoch, run_tag)
        cb2cl_lst.append(cl2cbk1)
        eer_cb2cl_lst.append(eer_cb2cl)
        cbcltf102_lst.append(cbcltf102)
        cbcltf103_lst.append(cbcltf103)
        cbcltf104_lst.append(cbcltf104)

        log_writer.add_scalars('recall@1/epoch',{'CB2CL':cl2cbk1},epoch)
        log_writer.add_scalars('EER/epoch',{'CB2CL':eer_cb2cl},epoch)
        log_writer.add_scalars('TARFAR10^-2/epoch',{'CB2CL':cbcltf102},epoch)
        log_writer.add_scalars('TARFAR10^-3/epoch',{'CB2CL':cbcltf103},epoch)
        log_writer.add_scalars('TARFAR10^-4/epoch',{'CB2CL':cbcltf104},epoch)
        log_writer.add_scalar('AvgLoss/epoch',avg_step_loss,epoch)

        torch.save(model.state_dict(), checkpoint_save_path + "combinedtrained_hkpolytest_" + args.model_name + "_" + str(args.lr_linear) + "_" + str(args.lr_swin) + "_" + str(args.batch_size) + str(epoch) + "_" + str(cl2cbk1) + ".pt")
    log_writer.close()

    def _report(prefix, history, best):
        # Print the best value and its 1-based epoch; skip when no epoch ran.
        if history:
            value = best(history)
            print(f"{prefix}: {value} at epoch {history.index(value)+1}")

    _report("Maximum recall@1 for CB2CL", cb2cl_lst, max)
    _report("Minimum EER for CB2CL", eer_cb2cl_lst, min)
    _report("Maximum TAR@FAR=10^-2 for CB2CL", cbcltf102_lst, max)
    _report("Maximum TAR@FAR=10^-3 for CB2CL", cbcltf103_lst, max)
    _report("Maximum TAR@FAR=10^-4 for CB2CL", cbcltf104_lst, max)


if __name__ == '__main__':
    main()
|
train_combined_fusion.py
ADDED
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import argparse
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torch.optim as optim
|
8 |
+
from torchvision import datasets, transforms
|
9 |
+
from torch.optim.lr_scheduler import StepLR, MultiStepLR
|
10 |
+
from datasets.hkpoly_test import hktest
|
11 |
+
from datasets.original_combined_train import Combined_original
|
12 |
+
from datasets.rb_loader import RB_loader
|
13 |
+
from loss import DualMSLoss_FineGrained_domain_agnostic_ft, DualMSLoss_FineGrained, DualMSLoss_FineGrained_domain_agnostic
|
14 |
+
import timm
|
15 |
+
from utils import Prev_RetMetric, RetMetric, compute_recall_at_k, l2_norm, compute_sharded_cosine_similarity, count_parameters
|
16 |
+
from pprint import pprint
|
17 |
+
import numpy as np
|
18 |
+
from tqdm import tqdm
|
19 |
+
from combined_sampler import BalancedSampler
|
20 |
+
from torch.utils.data.sampler import BatchSampler
|
21 |
+
from torch.nn.parallel import DataParallel
|
22 |
+
from model import SwinModel_Fusion as Model
|
23 |
+
import matplotlib.pyplot as plt
|
24 |
+
from sklearn.metrics import roc_curve, auc
|
25 |
+
import json
|
26 |
+
from torch.utils.tensorboard import SummaryWriter
|
27 |
+
|
28 |
+
def train(args, model, device, train_loader, test_loader, optimizers, epoch, loss_func, pl_arg, stepping, log_writer, checkpoint_save_path):
    """Run one epoch of fusion training and return the mean batch loss.

    For every batch the token tensors of all contactless/contact-based pairs
    are scored by ``model.combine_features``, giving an N x N similarity
    matrix that is fed to ``loss_func.ms_sample`` in both directions. Every
    50 batches the model is evaluated with ``hkpoly_test_fn`` and the metrics
    are written to ``log_writer``.

    Args:
        args: Parsed CLI namespace; ``log_interval`` and ``dry_run`` are read.
        model: Network returning ``(x_cl_tokens, x_cb_tokens)`` from
            ``model(x_cl, x_cb)`` and exposing ``combine_features``.
        device: Device the batch, similarity matrix and loss are moved to.
        train_loader: Iterable yielding five-element batches; the last two
            elements are ignored.
        test_loader: Loader handed to ``hkpoly_test_fn``.
        optimizers: Optimizers zeroed before and stepped after each backward
            pass.
        epoch: Epoch number forwarded to ``hkpoly_test_fn``.
        loss_func: Object exposing ``ms_sample(sim_matrix, target)``.
        pl_arg: Run tag forwarded to ``hkpoly_test_fn``.
        stepping: Counter used as the x-axis of the per-step metrics.
        log_writer: Writer receiving the per-step metrics via ``add_scalars``.
        checkpoint_save_path: Unused here; kept so existing callers keep
            working.

    Returns:
        ``(mean_loss, stepping)`` where ``mean_loss`` is a detached scalar
        tensor (NaN when the loader produced no batch).
    """
    model.train()
    batch_losses = []
    for batch_idx, (x_cl, x_cb, target,_,_) in enumerate(pbar := tqdm(train_loader)):
        x_cl, x_cb, target = x_cl.to(device), x_cb.to(device), target.to(device)
        for optimizer in optimizers:
            optimizer.zero_grad()
        x_cl_tokens, x_cb_tokens = model(x_cl, x_cb)

        N, M, D = x_cl_tokens.shape

        # Build all N*N pairs: row i*N+j pairs contactless sample i with
        # contact-based sample j.
        x = x_cl_tokens.unsqueeze(1).expand(N, N, M, D).reshape(N * N, M, D)
        y = x_cb_tokens.unsqueeze(0).expand(N, N, M, D).reshape(N * N, M, D)
        sim_matrix,_ = model.combine_features(x, y)
        sim_matrix = sim_matrix.view(N, N).to(device)

        # Use the configured device instead of .cuda() so CPU runs work too.
        loss = loss_func.ms_sample(sim_matrix, target).to(device) + loss_func.ms_sample(sim_matrix.t(), target.t()).to(device)
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()

        pbar.set_description(f"Loss {loss}")
        # Store a detached copy so the list does not keep autograd history alive.
        batch_losses.append(loss.detach())

        # Record the loss before leaving so a dry run still has a value to average.
        if batch_idx % args.log_interval == 0 and args.dry_run:
            break

        if (batch_idx + 1)%50 == 0:
            cl2clk1,cl2cbk1,eer_cb2cl,eer_cl2cl,cbcltf102,cbcltf103,cbcltf104,clcltf102,clcltf103,clcltf104 = hkpoly_test_fn(model, device, test_loader, epoch, pl_arg)
            log_writer.add_scalars('recall@1/step',{'CL2CL':cl2clk1,'CB2CL':cl2cbk1},stepping)
            log_writer.add_scalars('EER/step',{'CL2CL':eer_cl2cl,'CB2CL':eer_cb2cl},stepping)
            log_writer.add_scalars('TARFAR10^-2/step',{'CL2CL':clcltf102,'CB2CL':cbcltf102},stepping)
            log_writer.add_scalars('TARFAR10^-4/step',{'CL2CL':clcltf104,'CB2CL':cbcltf104},stepping)
            stepping+=1
            # Evaluation switches the model to eval mode; switch back so the
            # rest of the epoch is trained in training mode.
            model.train()

    if not batch_losses:
        # Empty loader: avoid dividing by zero.
        return torch.tensor(float("nan")), stepping
    return sum(batch_losses) / len(batch_losses), stepping
|
68 |
+
|
69 |
+
def l2_norm(input):
    """Return ``input`` divided by its L2 norm taken over dimension 1.

    An epsilon of 1e-12 is added to the squared norm, so an all-zero slice
    yields zeros instead of NaN. The norm is kept as a broadcastable
    dimension, which matches the previous result for 2-D tensors and also
    works for tensors of higher rank.

    Args:
        input: Tensor with at least two dimensions.

    Returns:
        Tensor with the same shape as ``input``.
    """
    norm = torch.sqrt(torch.pow(input, 2).sum(dim=1, keepdim=True) + 1e-12)
    return torch.div(input, norm)
|
77 |
+
|
78 |
+
def get_fused_cross_score_matrix(model, cl_tokens, cb_tokens, shard_size=20):
    """Score every contactless/contact-based token pair with the fusion head.

    The pair grid is processed in ``shard_size`` x ``shard_size`` tiles. Each
    tile is flattened into a batch of pairs for ``model.combine_features``,
    and the tile score is ``similarity - 0.1 * distance``. Tiles use their
    actual size, so sample counts that are not a multiple of ``shard_size``
    are handled correctly.

    Args:
        model: Object exposing ``combine_features(a, b)`` that returns
            ``(similarity_scores, distances)`` with one value per pair.
        cl_tokens: Sequence of contactless token tensors, concatenated along
            dimension 0. Each tensor is expected to be 3-D.
        cb_tokens: Sequence of contact-based token tensors, concatenated the
            same way.
        shard_size: Maximum number of rows/columns per tile (default 20).

    Returns:
        CPU tensor of shape ``(num_cl, num_cb)`` holding the fused scores.
    """
    cl_tokens = torch.cat(cl_tokens)
    cb_tokens = torch.cat(cb_tokens)
    num_cl = cl_tokens.shape[0]
    num_cb = cb_tokens.shape[0]
    cl_token_shape = cl_tokens.shape[1:]
    cb_token_shape = cb_tokens.shape[1:]
    similarity_matrix = torch.zeros((num_cl, num_cb))

    # The result is detached anyway, so skip building autograd graphs.
    with torch.no_grad():
        for i_start in tqdm(range(0, num_cl, shard_size)):
            shard_i = cl_tokens[i_start:i_start + shard_size]
            rows = shard_i.shape[0]
            for j_start in range(0, num_cb, shard_size):
                shard_j = cb_tokens[j_start:j_start + shard_size]
                cols = shard_j.shape[0]

                # Pair index r*cols+c holds (shard_i[r], shard_j[c]).
                pairwise_i = shard_i.unsqueeze(1).expand(-1, cols, -1, -1)
                pairwise_j = shard_j.unsqueeze(0).expand(rows, -1, -1, -1)

                similarity_scores, distances = model.combine_features(
                    pairwise_i.reshape(rows * cols, *cl_token_shape),
                    pairwise_j.reshape(rows * cols, *cb_token_shape),
                )
                scores = (similarity_scores - 0.1 * distances).reshape(rows, cols)
                similarity_matrix[i_start:i_start + rows, j_start:j_start + cols] = scores.cpu().detach()
    return similarity_matrix
|
100 |
+
|
101 |
+
def _tar_at_far(fpr, tpr, far):
    """Average the TPR of the two ROC points that bracket the given FAR."""
    fpr = np.asarray(fpr)
    below = np.flatnonzero(fpr < far)
    at_or_above = np.flatnonzero(fpr >= far)
    if below.size == 0 or at_or_above.size == 0:
        raise ValueError(f"ROC curve does not bracket FAR={far}")
    # Last point strictly below the target and first point at/above it.
    return (tpr[below[-1]] + tpr[at_or_above[0]]) / 2


def hkpoly_test_fn(model,device,test_loader,epoch,plot_argument):
    """Evaluate contact-based vs. contactless (CB2CL) matching on a test loader.

    Tokens are extracted for both modalities, a fused pairwise score matrix
    is computed, and ROC-based metrics plus recall@k are printed. The score
    matrix (``.npy``) and the ROC plot (``.png``) are written under
    ``combined_models_scores/``.

    Args:
        model: model exposing ``eval()``, ``get_tokens(x, modality)`` and
            whatever ``get_fused_cross_score_matrix`` requires.
        device: device the input batches are moved to.
        test_loader: iterable yielding ``(x_cl, x_cb, label)`` batches.
        epoch: value embedded in the output file names.
        plot_argument: sequence whose first three string items are embedded
            in the output file names.

    Returns:
        ``(recall@1, EER, TAR@FAR=1e-2, TAR@FAR=1e-3, TAR@FAR=1e-4)``, each
        multiplied by 100.
    """
    model.eval()
    cl_feats, cb_feats, batch_labels = [], [], []
    with torch.no_grad():
        for (x_cl, x_cb, label) in tqdm(test_loader):
            x_cl, x_cb = x_cl.to(device), x_cb.to(device)
            cl_feats.append(model.get_tokens(x_cl,'contactless'))
            cb_feats.append(model.get_tokens(x_cb,'contactbased'))
            batch_labels.append(label.detach().cpu().numpy())

        # Both modalities come from the same batches, so they share labels.
        cl_label = torch.from_numpy(np.concatenate(batch_labels))
        cb_label = torch.from_numpy(np.concatenate(batch_labels))

        # CB2CL
        scores_mat = get_fused_cross_score_matrix(model, cl_feats, cb_feats)

    os.makedirs("combined_models_scores", exist_ok=True)
    tag = plot_argument[0]+"_"+plot_argument[1]+"_"+plot_argument[2]
    np.save("combined_models_scores/task1_cb2cl_score_matrix_"+str(epoch)+"_"+tag+".npy", scores_mat.cpu().detach().numpy())

    scores = scores_mat.cpu().detach().numpy().flatten().tolist()
    # Rows of scores_mat are contactless samples and columns contact-based
    # ones, so the genuine/impostor mask is built in the same orientation.
    labels = torch.eq(cl_label.view(-1,1), cb_label.view(1,-1)).flatten().tolist()

    fpr,tpr,_ = roc_curve(labels,scores)
    tar_far_102 = _tar_at_far(fpr, tpr, 0.01)
    tar_far_103 = _tar_at_far(fpr, tpr, 0.001)
    tar_far_104 = _tar_at_far(fpr, tpr, 0.0001)
    fnr = 1 - tpr
    # EER is read off at the ROC point where FNR and FPR are closest.
    EER = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    roc_auc = auc(fpr, tpr)

    fig = plt.figure()
    plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], 'k--', label='No Skill')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve CB2CL task1')
    plt.legend(loc="lower right")
    plt.savefig("combined_models_scores/roc_curve_cb2cl_task1_"+"_"+tag+str(epoch)+".png", dpi=300, bbox_inches='tight')
    # Release the figure; otherwise one figure per evaluation stays alive.
    plt.close(fig)

    print(f"ROCAUC for CB2CL: {roc_auc * 100} %")
    print(f"EER for CB2CL: {EER * 100} %")
    eer_cb2cl = EER * 100
    print(f"TAR@FAR=10^-2 for CB2CL: {tar_far_102 * 100} %")
    print(f"TAR@FAR=10^-3 for CB2CL: {tar_far_103 * 100} %")
    print(f"TAR@FAR=10^-4 for CB2CL: {tar_far_104 * 100} %")
    cbcltf102 = tar_far_102 * 100
    cbcltf103 = tar_far_103 * 100
    cbcltf104 = tar_far_104 * 100

    # Compute each recall once and reuse recall@1 for the return value.
    cl2cbk1 = compute_recall_at_k(scores_mat, cl_label, cb_label, 1) * 100
    print(f"R@1 for CB2CL: {cl2cbk1} %")
    for k in (10, 50, 100):
        print(f"R@{k} for CB2CL: {compute_recall_at_k(scores_mat, cl_label, cb_label, k) * 100} %")
    torch.cuda.empty_cache()

    return cl2cbk1,eer_cb2cl,cbcltf102,cbcltf103,cbcltf104
|
169 |
+
|
170 |
+
def main():
    """Train the fusion head and run the CB2CL evaluation after every epoch.

    Parses the command line, loads the two pretrained checkpoints, freezes
    the backbone, trains with ``train`` for ``--epochs`` epochs, evaluates
    with ``hkpoly_test_fn`` after each epoch, logs to TensorBoard, saves a
    checkpoint per epoch and finally prints the best CB2CL metrics.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    # nargs='+' collects one or more paths; type=list would have split a
    # single command-line string into individual characters.
    parser.add_argument('--manifest-list', nargs='+', default=mani_lst,
                        help='list of manifest files')
    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--test-batch-size', type=int, default=16, metavar='N',
                        help='input batch size for testing (default: 16)')
    parser.add_argument('--epochs', type=int, default=50, metavar='N',
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--lr_fusion', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    # NOTE(review): --gamma is parsed but the scheduler below uses a fixed 0.5.
    parser.add_argument('--gamma', type=float, default=0.9, metavar='M',
                        help='Learning rate step gamma (default: 0.9)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--warmup', type=int, default=2, metavar='N',
                        help='warm up rate for feature extractor')
    parser.add_argument('--model-name', type=str, default="swinmodel",
                        help='Name of the model for checkpointing')
    args = parser.parse_args()

    # Seed before the model is built so freshly initialised layers are
    # covered by --seed as well.
    torch.manual_seed(args.seed)

    # Honour --no-cuda and fall back to CPU when no GPU is available.
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = Model().to(device)
    ckpt_combined_phase1_ft = "ridgeformer_checkpoints/combined_models_check/phase1_ft_hkpoly.pt"
    ckpt_combined_phase2 = "ridgeformer_checkpoints/phase2_scratch.pt"

    model.load_pretrained_models(ckpt_combined_phase1_ft, ckpt_combined_phase2)
    model.freeze_backbone()
    checkpoint_save_path = "ridgeformer_checkpoints/"

    # makedirs also creates the parent "experiment_logs" directory.
    os.makedirs("experiment_logs/"+args.model_name, exist_ok=True)

    log_writer = SummaryWriter("experiment_logs/"+args.model_name+"/",comment = str(args.batch_size)+str(args.lr_fusion))

    print("loading Normal RGB images -----------------------------")
    train_dataset = Combined_original(args.manifest_list,split="train")
    val_dataset = hktest(split="test")

    balanced_sampler = BalancedSampler(train_dataset, batch_size = args.batch_size, images_per_class = 2)
    batch_sampler = BatchSampler(balanced_sampler, batch_size = args.batch_size, drop_last = True)

    train_kwargs = {'batch_sampler': batch_sampler}
    test_kwargs = {'batch_size': args.test_batch_size}

    if use_cuda:
        cuda_kwargs = {
            'num_workers': 1,
            'pin_memory': True
        }
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)

    print("Number of Trainable Parameters: - ", count_parameters(model))

    loss_func = DualMSLoss_FineGrained()
    optimizer_fusion = optim.AdamW(
        [
            {"params": model.output_logit_mlp.parameters(), "lr":args.lr_fusion},
            {"params": model.fusion.parameters(), "lr":args.lr_fusion},
            {"params": model.sep_token, "lr":args.lr_fusion},
            {"params": model.encoder_layer.parameters(), "lr":args.lr_fusion},
        ],
        weight_decay=0.000001,
        lr=args.lr_fusion)

    scheduler = MultiStepLR(optimizer_fusion, milestones = [3,6,9,14], gamma=0.5)

    # hkpoly_test_fn only reports CB2CL metrics (five values), so only those
    # are tracked; unpacking ten values raised ValueError after epoch 1.
    cb2cl_lst, eer_cb2cl_lst = [], []
    cbcltf102_lst, cbcltf103_lst, cbcltf104_lst = [], [], []
    stepping = 1
    try:
        for epoch in range(1, args.epochs + 1):
            print(f"running epoch------ {epoch}")
            avg_step_loss,stepping = train(args, model, device, train_loader, test_loader, [optimizer_fusion], epoch, loss_func, [args.model_name,str(args.batch_size),str(args.lr_fusion)],stepping,log_writer, checkpoint_save_path)

            current_lr = scheduler.get_last_lr()
            print(f"Learning Rate for {epoch} for linear = {current_lr}")
            print(f"Learning Rate for {epoch} for swin = {current_lr}")

            log_writer.add_scalar('Liner_LR/epoch',current_lr[0],epoch)
            log_writer.add_scalar('Swin_LR/epoch',current_lr[0],epoch)

            scheduler.step()

            cl2cbk1,eer_cb2cl,cbcltf102,cbcltf103,cbcltf104 = hkpoly_test_fn(model, device, test_loader, epoch, [args.model_name,str(args.batch_size),str(args.lr_fusion)])
            cb2cl_lst.append(cl2cbk1)
            eer_cb2cl_lst.append(eer_cb2cl)
            cbcltf102_lst.append(cbcltf102)
            cbcltf103_lst.append(cbcltf103)
            cbcltf104_lst.append(cbcltf104)

            log_writer.add_scalars('recall@1/epoch',{'CB2CL':cl2cbk1},epoch)
            log_writer.add_scalars('EER/epoch',{'CB2CL':eer_cb2cl},epoch)
            log_writer.add_scalars('TARFAR10^-2/epoch',{'CB2CL':cbcltf102},epoch)
            log_writer.add_scalars('TARFAR10^-4/epoch',{'CB2CL':cbcltf104},epoch)
            log_writer.add_scalar('AvgLoss/epoch',avg_step_loss,epoch)

            torch.save(model.state_dict(), checkpoint_save_path + "combinedtrained_hkpolytest_" + args.model_name + "_" + str(args.lr_fusion) + "_" + str(args.batch_size) + str(epoch) + "_" + str(cl2cbk1) + ".pt")
    finally:
        # Flush and close the TensorBoard writer even if an epoch fails.
        log_writer.close()

    if not cb2cl_lst:
        # No epoch was run (e.g. --epochs 0); nothing to summarise.
        return

    print(f"Maximum recall@1 for CB2CL: {max(cb2cl_lst)} at epoch {cb2cl_lst.index(max(cb2cl_lst))+1}")
    print(f"Minimum EER for CB2CL: {min(eer_cb2cl_lst)} at epoch {eer_cb2cl_lst.index(min(eer_cb2cl_lst))+1}")
    print(f"Maximum TAR@FAR=10^-2 for CB2CL: {max(cbcltf102_lst)} at epoch {cbcltf102_lst.index(max(cbcltf102_lst))+1}")
    print(f"Maximum TAR@FAR=10^-3 for CB2CL: {max(cbcltf103_lst)} at epoch {cbcltf103_lst.index(max(cbcltf103_lst))+1}")
    print(f"Maximum TAR@FAR=10^-4 for CB2CL: {max(cbcltf104_lst)} at epoch {cbcltf104_lst.index(max(cbcltf104_lst))+1}")


if __name__ == '__main__':
    main()
|
utils.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
class RetMetric(object):
    """Recall@k over a precomputed query-by-gallery similarity matrix.

    Args:
        sim_mat: similarity matrix indexed as ``sim_mat[query][gallery]``
            (presumably a NumPy array, since rows are indexed with boolean
            masks and passed to ``np.sort`` / ``np.max`` — confirm with callers).
        labels: pair ``(gallery_labels, query_labels)``.
    """

    def __init__(self, sim_mat, labels):
        self.gallery_labels, self.query_labels = labels
        self.sim_mat = sim_mat
        # Always False here; when set, the best positive is treated as the
        # query itself and the second-best positive is used instead.
        self.is_equal_query = False

    def recall_k(self, k=1):
        """Return the fraction of queries whose best positive ranks in the top ``k``.

        A query counts as a match when fewer than ``k`` negatives score
        strictly higher than its reference positive. Queries with no positive
        in the gallery count as misses, and an empty query set yields ``0.0``.
        """
        m = len(self.sim_mat)
        if m == 0:
            return 0.0

        match_counter = 0

        for i in range(m):
            is_positive = self.gallery_labels == self.query_labels[i]
            pos_sim = self.sim_mat[i][is_positive]
            if len(pos_sim) == 0:
                # Nothing to retrieve for this query; np.max would raise.
                continue
            neg_sim = self.sim_mat[i][self.gallery_labels != self.query_labels[i]]

            if self.is_equal_query and len(pos_sim) > 1:
                thresh = np.sort(pos_sim)[-2]
            else:
                thresh = np.max(pos_sim)

            if np.sum(neg_sim > thresh) < k:
                match_counter += 1
        return float(match_counter) / m
|
26 |
+
|
27 |
+
class Prev_RetMetric(object):
    """Recall@k computed from feature matrices via dot-product similarity.

    Args:
        feats: either a two-element list ``[gallery_feats, query_feats]`` or a
            single feature matrix used as both gallery and query.
        labels: ``[gallery_labels, query_labels]`` in the two-set case,
            otherwise a single label array.
        cl2cl: when true, the diagonal of the similarity matrix is zeroed.
    """

    def __init__(self, feats, labels, cl2cl=True):
        if isinstance(feats, list) and len(feats) == 2:
            # feats = [gallery_feats, query_feats]
            # labels = [gallery_labels, query_labels]
            self.is_equal_query = False

            self.gallery_feats, self.query_feats = feats
            self.gallery_labels, self.query_labels = labels

        else:
            self.is_equal_query = True
            self.gallery_feats = self.query_feats = feats
            self.gallery_labels = self.query_labels = labels

        self.sim_mat = np.matmul(self.query_feats, np.transpose(self.gallery_feats))
        if cl2cl:
            # Zero the diagonal (sample-vs-itself scores).
            self.sim_mat = self.sim_mat * (1 - np.identity(self.sim_mat.shape[0]))

    def recall_k(self, k=1):
        """Return the fraction of queries whose reference positive ranks in the top ``k``.

        A query counts as a match when fewer than ``k`` negatives score
        strictly higher than its reference positive. In the shared
        gallery/query case the second-highest positive is the reference; if
        only one positive exists, that one is used, as ``RetMetric`` does.
        Queries without any positive count as misses, and an empty query set
        yields ``0.0``.
        """
        m = len(self.sim_mat)
        if m == 0:
            return 0.0

        match_counter = 0

        for i in range(m):
            pos_sim = self.sim_mat[i][self.gallery_labels == self.query_labels[i]]
            if len(pos_sim) == 0:
                # Nothing to retrieve for this query; np.max would raise.
                continue
            neg_sim = self.sim_mat[i][self.gallery_labels != self.query_labels[i]]

            # Guard the [-2] lookup: with a single positive it raised IndexError.
            if self.is_equal_query and len(pos_sim) > 1:
                thresh = np.sort(pos_sim)[-2]
            else:
                thresh = np.max(pos_sim)

            if np.sum(neg_sim > thresh) < k:
                match_counter += 1
        return float(match_counter) / m
|
63 |
+
|
64 |
+
def compute_recall_at_k(similarity_matrix, p_labels, g_labels, k):
    """Return the fraction of probes with a same-label gallery item in their top ``k``.

    Args:
        similarity_matrix: tensor indexed as ``[probe][gallery]``; higher
            values rank first.
        p_labels: tensor of probe labels; its first dimension sets the number
            of probes.
        g_labels: gallery labels, indexed by gallery position.
        k: number of top-ranked gallery items inspected per probe.

    Returns:
        The hit rate as a float.
    """
    probe_count = p_labels.size(0)
    hits = 0.0
    for probe_idx in range(probe_count):
        wanted = p_labels[probe_idx]
        ranking = torch.argsort(similarity_matrix[probe_idx], descending=True)
        # Stop at the first match within the k best-scoring gallery entries.
        for gallery_idx in ranking[:k]:
            if g_labels[gallery_idx] == wanted:
                hits += 1
                break
    return hits / probe_count
|
76 |
+
|
77 |
+
def count_parameters(model):
    """Return the total number of elements in the model's trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total
|
79 |
+
|
80 |
+
def l2_norm(input):
    """L2-normalise each row of ``input``.

    Squares are summed over dimension 1 and the norm is reshaped to a column
    before dividing, so the input is expected to be 2-D.

    Args:
        input: tensor whose rows are normalised.

    Returns:
        A tensor of the same size as ``input``.
    """
    shape = input.size()
    # 1e-12 is added before the square root so all-zero rows do not divide
    # by zero.
    sum_of_squares = torch.sum(torch.pow(input, 2), 1).add_(1e-12)
    column_norm = torch.sqrt(sum_of_squares).view(-1, 1)
    scaled = torch.div(input, column_norm.expand_as(input))
    return scaled.view(shape)
|
89 |
+
|
90 |
+
def compute_sharded_cosine_similarity(tensor1, tensor2, shard_size):
    """Mean token-to-token cosine similarity between every pair of samples.

    For each pair ``(i, j)`` the cosine similarity between every token of
    ``tensor1[i]`` and every token of ``tensor2[j]`` is computed and averaged.
    Work is done in ``shard_size`` x ``shard_size`` tiles to bound memory.

    Args:
        tensor1: tensor of shape ``(B1, T1, D)``.
        tensor2: tensor of shape ``(B2, T2, D)``; ``B2`` and ``T2`` may differ
            from ``B1`` and ``T1``.
        shard_size: number of samples per tile along each axis; must be
            positive.

    Returns:
        A ``(B1, B2)`` tensor on ``tensor1``'s device.

    Raises:
        ValueError: if ``shard_size`` is not positive.
    """
    if shard_size <= 0:
        raise ValueError("shard_size must be a positive integer")

    B1, T1, _ = tensor1.shape
    # Use tensor2's own sizes: reusing tensor1's batch size crashed on a
    # smaller tensor2 and silently dropped rows of a larger one.
    B2, T2, _ = tensor2.shape
    average_sim_matrix = torch.zeros((B1, B2), device=tensor1.device)

    for start_idx1 in tqdm(range(0, B1, shard_size)):
        end_idx1 = min(start_idx1 + shard_size, B1)
        # (b1, 1, T1, 1, D): broadcast over tensor2's samples and tokens.
        shard_tensor1_expanded = tensor1[start_idx1:end_idx1].unsqueeze(1).unsqueeze(3)

        for start_idx2 in range(0, B2, shard_size):
            end_idx2 = min(start_idx2 + shard_size, B2)
            # (1, b2, 1, T2, D): broadcast over tensor1's samples and tokens.
            shard_tensor2_expanded = tensor2[start_idx2:end_idx2].unsqueeze(0).unsqueeze(2)

            # Cosine similarity over the embedding axis -> (b1, b2, T1, T2).
            shard_cos_sim = F.cosine_similarity(shard_tensor1_expanded, shard_tensor2_expanded, dim=-1)

            # Sum over both token axes; each tile is written exactly once.
            average_sim_matrix[start_idx1:end_idx1, start_idx2:end_idx2] += torch.sum(shard_cos_sim, dim=[2, 3])

    # Normalise by the number of token pairs per sample pair.
    average_sim_matrix /= (T1 * T2)

    return average_sim_matrix
|