from pytorch_metric_learning import losses

import torch
import torch.nn as nn
import torch.nn.init
import torchvision.models as models
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.nn.utils.weight_norm import weight_norm
import torch.backends.cudnn as cudnn
from torch.nn.utils.clip_grad import clip_grad_norm_
import numpy as np
import os
import torch.nn.functional as F
import itertools

# Anomaly detection helps localise autograd errors but slows training noticeably;
# disable it for production runs.
torch.autograd.set_detect_anomaly(True)


class DualMSLoss_FineGrained(nn.Module):
    """
    Dual multi-similarity (MS) loss over contactless and contact-based
    embeddings: within-modality terms (CL-CL, CB-CB) plus a cross-modality
    term (CB-CL), each mined with the MS positive/negative weighting.
    """
    def __init__(self, margin=0, max_violation=False):
        super(DualMSLoss_FineGrained, self).__init__()
        self.margin = margin
        self.max_violation = max_violation
        self.thresh = 0.5
        # The constructor argument is overridden by the fixed MS margin below.
        self.margin = 0.7
        self.scale_pos = 2
        self.scale_neg = 40.0

    def ms_sample(self, sim_mat, label):
        """Multi-similarity loss with hard-pair mining over a similarity matrix."""
        pos_exp = torch.exp(-self.scale_pos * (sim_mat - self.thresh))
        neg_exp = torch.exp(self.scale_neg * (sim_mat - self.thresh))
        # Pairs that share a label are positives; all remaining pairs are negatives.
        pos_mask = torch.eq(label.view(-1, 1) - label.view(1, -1), 0.0).float().to(sim_mat.device)
        neg_mask = 1 - pos_mask
        P_sim = torch.where(pos_mask == 1, sim_mat, torch.ones_like(pos_exp) * 1e16)
        N_sim = torch.where(neg_mask == 1, sim_mat, torch.ones_like(neg_exp) * -1e16)
        min_P_sim, _ = torch.min(P_sim, dim=1, keepdim=True)
        max_N_sim, _ = torch.max(N_sim, dim=1, keepdim=True)
        # Mine positives below the hardest negative plus the margin, and negatives
        # above the hardest positive minus the margin.
        hard_P_sim = torch.where(P_sim - self.margin < max_N_sim, pos_exp, torch.zeros_like(pos_exp)).sum(dim=-1)
        hard_N_sim = torch.where(N_sim + self.margin > min_P_sim, neg_exp, torch.zeros_like(neg_exp)).sum(dim=-1)
        pos_loss = torch.log(1 + hard_P_sim).sum() / self.scale_pos
        neg_loss = torch.log(1 + hard_N_sim).sum() / self.scale_neg

        return pos_loss + neg_loss

    def ms_sample_cbcb_clcl(self, sim_mat, label):
        """MS loss for within-modality similarity matrices; self-similarities
        on the diagonal are excluded from both positives and negatives."""
        pos_exp = torch.exp(-self.scale_pos * (sim_mat - self.thresh))
        neg_exp = torch.exp(self.scale_neg * (sim_mat - self.thresh))
        pos_mask = torch.eq(label.view(-1, 1) - label.view(1, -1), 0.0).float().to(sim_mat.device)

        # Adding the identity bumps diagonal entries to 2, so they match neither
        # the positive (== 1) nor the negative (== 0) selections below.
        pos_mask = pos_mask + torch.eye(pos_mask.shape[0], device=sim_mat.device)
        P_sim = torch.where(pos_mask == 1, sim_mat, torch.ones_like(pos_exp) * 1e16)
        N_sim = torch.where(pos_mask == 0, sim_mat, torch.ones_like(neg_exp) * -1e16)
        min_P_sim, _ = torch.min(P_sim, dim=1, keepdim=True)
        max_N_sim, _ = torch.max(N_sim, dim=1, keepdim=True)
        hard_P_sim = torch.where(P_sim - self.margin < max_N_sim, pos_exp, torch.zeros_like(pos_exp)).sum(dim=-1)
        hard_N_sim = torch.where(N_sim + self.margin > min_P_sim, neg_exp, torch.zeros_like(neg_exp)).sum(dim=-1)
        pos_loss = torch.log(1 + hard_P_sim).sum() / self.scale_pos
        neg_loss = torch.log(1 + hard_N_sim).sum() / self.scale_neg

        return pos_loss + neg_loss

    def ms_sample_cbcb_clcl_trans(self, sim_mat, label):
        """Variant of the within-modality MS loss that drops the diagonal from
        the label mask, giving an (n, n-1) mask; sim_mat is expected in the
        same diagonal-removed layout."""
        pos_exp = torch.exp(-self.scale_pos * (sim_mat - self.thresh))
        neg_exp = torch.exp(self.scale_neg * (sim_mat - self.thresh))
        pos_mask = torch.eq(label.view(-1, 1) - label.view(1, -1), 0.0).float().to(sim_mat.device)

        # Drop the diagonal (self pairs) and repack each row to length n-1.
        n_sha = pos_mask.shape[0]
        mask_pos = ~torch.eye(n_sha, dtype=torch.bool, device=pos_mask.device)
        pos_mask = pos_mask[mask_pos].reshape(n_sha, n_sha - 1)

        neg_mask = 1 - pos_mask
        P_sim = torch.where(pos_mask == 1, sim_mat, torch.ones_like(pos_exp) * 1e16)
        N_sim = torch.where(neg_mask == 1, sim_mat, torch.ones_like(neg_exp) * -1e16)
        min_P_sim, _ = torch.min(P_sim, dim=1, keepdim=True)
        max_N_sim, _ = torch.max(N_sim, dim=1, keepdim=True)
        hard_P_sim = torch.where(P_sim - self.margin < max_N_sim, pos_exp, torch.zeros_like(pos_exp)).sum(dim=-1)
        hard_N_sim = torch.where(N_sim + self.margin > min_P_sim, neg_exp, torch.zeros_like(neg_exp)).sum(dim=-1)
        pos_loss = torch.log(1 + hard_P_sim).sum() / self.scale_pos
        neg_loss = torch.log(1 + hard_N_sim).sum() / self.scale_neg

        return pos_loss + neg_loss

    def compute_sharded_cosine_similarity(self, tensor1, tensor2, shard_size):
        """Token-level cosine similarity between two (B, T, D) tensors,
        accumulated over shards of the token dimension to limit memory use."""
        B, T, D = tensor1.shape
        average_sim_matrix = torch.zeros((B, B), device=tensor1.device)

        for start_idx in range(0, T, shard_size):
            end_idx = min(start_idx + shard_size, T)

            shard_tensor1 = tensor1[:, start_idx:end_idx, :]
            shard_tensor2 = tensor2[:, start_idx:end_idx, :]

            # (B, 1, t, 1, D) against (1, B, 1, t, D) -> pairwise sims of shape (B, B, t, t)
            shard_tensor1_expanded = shard_tensor1.unsqueeze(1).unsqueeze(3)
            shard_tensor2_expanded = shard_tensor2.unsqueeze(0).unsqueeze(2)

            shard_cos_sim = F.cosine_similarity(shard_tensor1_expanded, shard_tensor2_expanded, dim=-1)

            average_sim_matrix += torch.sum(shard_cos_sim, dim=[2, 3])

        # Only token pairs that fall in the same shard window are compared, yet the
        # sum is normalised by all T*T pairs, so cross-window pairs count as zero.
        average_sim_matrix /= (T * T)

        return average_sim_matrix

    def forward(self, x_contactless, x_contactbased, x_cl_tokens, x_cb_tokens, labels, device):
        # Cosine-similarity matrices between L2-normalised global embeddings:
        # contactless-contactless, contactbased-contactbased and cross-modal.
        sim_mat_clcl = F.linear(self.l2_norm(x_contactless), self.l2_norm(x_contactless))
        sim_mat_cbcb = F.linear(self.l2_norm(x_contactbased), self.l2_norm(x_contactbased))
        sim_mat_cbcl = F.linear(self.l2_norm(x_contactbased), self.l2_norm(x_contactless))

        # Within-modality terms exclude the diagonal; the cross-modal term keeps it,
        # since a contact-based/contactless pair of the same identity is a true positive.
        loss2 = self.ms_sample_cbcb_clcl(sim_mat_clcl, labels) + self.ms_sample_cbcb_clcl(sim_mat_clcl.t(), labels)
        loss3 = self.ms_sample_cbcb_clcl(sim_mat_cbcb, labels) + self.ms_sample_cbcb_clcl(sim_mat_cbcb.t(), labels)
        loss4 = self.ms_sample(sim_mat_cbcl, labels) + self.ms_sample(sim_mat_cbcl.t(), labels)

        return loss4 + loss2 + loss3

    def l2_norm(self, input):
        """Row-wise L2 normalisation of a 2-D batch of embeddings."""
        input_size = input.size()
        buffer = torch.pow(input, 2)
        normp = torch.sum(buffer, 1).add_(1e-12)
        norm = torch.sqrt(normp)
        _output = torch.div(input, norm.view(-1, 1).expand_as(input))
        output = _output.view(input_size)

        return output


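# Illustrative usage sketch, not part of the original module: it only shows the
# call signature and tensor shapes expected by DualMSLoss_FineGrained. The batch
# size, embedding dimension and two-samples-per-identity labelling below are
# assumptions made for the example.
def _example_dual_ms_loss():
    loss_fn = DualMSLoss_FineGrained()
    b, d = 8, 256
    x_cl = torch.randn(b, d).cuda()                        # contactless embeddings
    x_cb = torch.randn(b, d).cuda()                        # contact-based embeddings
    labels = torch.arange(b // 2).repeat_interleave(2).cuda()
    # The token arguments are not used by this loss, so placeholders are passed.
    return loss_fn(x_cl, x_cb, None, None, labels, x_cl.device)

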
class DualMSLoss_FineGrained_domain_agnostic(nn.Module):
    """
    Dual multi-similarity loss with an additional domain-classification
    (cross-entropy) term that encourages domain-agnostic embeddings.
    """
    def __init__(self, margin=0, max_violation=False):
        super(DualMSLoss_FineGrained_domain_agnostic, self).__init__()
        self.margin = margin
        self.max_violation = max_violation
        self.thresh = 0.5
        # The constructor argument is overridden by the fixed MS margin below.
        self.margin = 0.5
        self.scale_pos = 2
        self.scale_neg = 40.0
        self.criterion = nn.CrossEntropyLoss()

    def ms_sample(self, sim_mat, label):
        """Multi-similarity loss with hard-pair mining over a similarity matrix."""
        pos_exp = torch.exp(-self.scale_pos * (sim_mat - self.thresh))
        neg_exp = torch.exp(self.scale_neg * (sim_mat - self.thresh))
        # Pairs that share a label are positives; all remaining pairs are negatives.
        pos_mask = torch.eq(label.view(-1, 1) - label.view(1, -1), 0.0).float().to(sim_mat.device)
        neg_mask = 1 - pos_mask
        P_sim = torch.where(pos_mask == 1, sim_mat, torch.ones_like(pos_exp) * 1e16)
        N_sim = torch.where(neg_mask == 1, sim_mat, torch.ones_like(neg_exp) * -1e16)
        min_P_sim, _ = torch.min(P_sim, dim=1, keepdim=True)
        max_N_sim, _ = torch.max(N_sim, dim=1, keepdim=True)
        # Mine positives below the hardest negative plus the margin, and negatives
        # above the hardest positive minus the margin.
        hard_P_sim = torch.where(P_sim - self.margin < max_N_sim, pos_exp, torch.zeros_like(pos_exp)).sum(dim=-1)
        hard_N_sim = torch.where(N_sim + self.margin > min_P_sim, neg_exp, torch.zeros_like(neg_exp)).sum(dim=-1)
        pos_loss = torch.log(1 + hard_P_sim).sum() / self.scale_pos
        neg_loss = torch.log(1 + hard_N_sim).sum() / self.scale_neg

        return pos_loss + neg_loss

    def ms_sample_cbcb_clcl(self, sim_mat, label):
        """MS loss for within-modality similarity matrices; self-similarities
        on the diagonal are excluded from both positives and negatives."""
        pos_exp = torch.exp(-self.scale_pos * (sim_mat - self.thresh))
        neg_exp = torch.exp(self.scale_neg * (sim_mat - self.thresh))
        pos_mask = torch.eq(label.view(-1, 1) - label.view(1, -1), 0.0).float().to(sim_mat.device)

        # Adding the identity bumps diagonal entries to 2, so they match neither
        # the positive (== 1) nor the negative (== 0) selections below.
        pos_mask = pos_mask + torch.eye(pos_mask.shape[0], device=sim_mat.device)
        P_sim = torch.where(pos_mask == 1, sim_mat, torch.ones_like(pos_exp) * 1e16)
        N_sim = torch.where(pos_mask == 0, sim_mat, torch.ones_like(neg_exp) * -1e16)
        min_P_sim, _ = torch.min(P_sim, dim=1, keepdim=True)
        max_N_sim, _ = torch.max(N_sim, dim=1, keepdim=True)
        hard_P_sim = torch.where(P_sim - self.margin < max_N_sim, pos_exp, torch.zeros_like(pos_exp)).sum(dim=-1)
        hard_N_sim = torch.where(N_sim + self.margin > min_P_sim, neg_exp, torch.zeros_like(neg_exp)).sum(dim=-1)
        pos_loss = torch.log(1 + hard_P_sim).sum() / self.scale_pos
        neg_loss = torch.log(1 + hard_N_sim).sum() / self.scale_neg

        return pos_loss + neg_loss

    def ms_sample_cbcb_clcl_trans(self, sim_mat, label):
        """Variant of the within-modality MS loss that drops the diagonal from
        the label mask, giving an (n, n-1) mask; sim_mat is expected in the
        same diagonal-removed layout."""
        pos_exp = torch.exp(-self.scale_pos * (sim_mat - self.thresh))
        neg_exp = torch.exp(self.scale_neg * (sim_mat - self.thresh))
        pos_mask = torch.eq(label.view(-1, 1) - label.view(1, -1), 0.0).float().to(sim_mat.device)

        # Drop the diagonal (self pairs) and repack each row to length n-1.
        n_sha = pos_mask.shape[0]
        mask_pos = ~torch.eye(n_sha, dtype=torch.bool, device=pos_mask.device)
        pos_mask = pos_mask[mask_pos].reshape(n_sha, n_sha - 1)

        neg_mask = 1 - pos_mask
        P_sim = torch.where(pos_mask == 1, sim_mat, torch.ones_like(pos_exp) * 1e16)
        N_sim = torch.where(neg_mask == 1, sim_mat, torch.ones_like(neg_exp) * -1e16)
        min_P_sim, _ = torch.min(P_sim, dim=1, keepdim=True)
        max_N_sim, _ = torch.max(N_sim, dim=1, keepdim=True)
        hard_P_sim = torch.where(P_sim - self.margin < max_N_sim, pos_exp, torch.zeros_like(pos_exp)).sum(dim=-1)
        hard_N_sim = torch.where(N_sim + self.margin > min_P_sim, neg_exp, torch.zeros_like(neg_exp)).sum(dim=-1)
        pos_loss = torch.log(1 + hard_P_sim).sum() / self.scale_pos
        neg_loss = torch.log(1 + hard_N_sim).sum() / self.scale_neg

        return pos_loss + neg_loss

    def compute_sharded_cosine_similarity(self, tensor1, tensor2, shard_size):
        """Token-level cosine similarity between two (B, T, D) tensors,
        accumulated over shards of the token dimension to limit memory use."""
        B, T, D = tensor1.shape
        average_sim_matrix = torch.zeros((B, B), device=tensor1.device)

        for start_idx in range(0, T, shard_size):
            end_idx = min(start_idx + shard_size, T)

            shard_tensor1 = tensor1[:, start_idx:end_idx, :]
            shard_tensor2 = tensor2[:, start_idx:end_idx, :]

            # (B, 1, t, 1, D) against (1, B, 1, t, D) -> pairwise sims of shape (B, B, t, t)
            shard_tensor1_expanded = shard_tensor1.unsqueeze(1).unsqueeze(3)
            shard_tensor2_expanded = shard_tensor2.unsqueeze(0).unsqueeze(2)

            shard_cos_sim = F.cosine_similarity(shard_tensor1_expanded, shard_tensor2_expanded, dim=-1)

            average_sim_matrix += torch.sum(shard_cos_sim, dim=[2, 3])

        # Only token pairs that fall in the same shard window are compared, yet the
        # sum is normalised by all T*T pairs, so cross-window pairs count as zero.
        average_sim_matrix /= (T * T)

        return average_sim_matrix

    def forward(self, x_contactless, x_contactbased, x_cl_tokens, x_cb_tokens, labels, device,
                domain_class_cl, domain_class_cb, domain_class_cl_gt, domain_class_cb_gt):
        # Cosine-similarity matrices between L2-normalised global embeddings.
        sim_mat_clcl = F.linear(self.l2_norm(x_contactless), self.l2_norm(x_contactless))
        sim_mat_cbcb = F.linear(self.l2_norm(x_contactbased), self.l2_norm(x_contactbased))
        sim_mat_cbcl = F.linear(self.l2_norm(x_contactbased), self.l2_norm(x_contactless))

        loss2 = self.ms_sample_cbcb_clcl(sim_mat_clcl, labels) + self.ms_sample_cbcb_clcl(sim_mat_clcl.t(), labels)
        loss3 = self.ms_sample_cbcb_clcl(sim_mat_cbcb, labels) + self.ms_sample_cbcb_clcl(sim_mat_cbcb.t(), labels)
        loss4 = self.ms_sample(sim_mat_cbcl, labels) + self.ms_sample(sim_mat_cbcl.t(), labels)

        # Domain-classifier logits and ground truth from both modalities are pooled
        # into a single cross-entropy term, weighted by a factor of 3.
        pred = torch.cat([domain_class_cl, domain_class_cb])
        gt = torch.cat([domain_class_cl_gt, domain_class_cb_gt])
        domain_class_loss = self.criterion(pred, gt)

        return loss4 + loss2 + loss3 + (3 * domain_class_loss)

    def l2_norm(self, input):
        """Row-wise L2 normalisation of a 2-D batch of embeddings."""
        input_size = input.size()
        buffer = torch.pow(input, 2)
        normp = torch.sum(buffer, 1).add_(1e-12)
        norm = torch.sqrt(normp)
        _output = torch.div(input, norm.view(-1, 1).expand_as(input))
        output = _output.view(input_size)

        return output


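# Illustrative usage sketch, not part of the original module: the domain-agnostic
# variant additionally expects domain-classifier logits and integer domain labels
# for both modalities. The number of domains (2) and all shapes below are
# assumptions made for the example.
def _example_domain_agnostic_loss():
    loss_fn = DualMSLoss_FineGrained_domain_agnostic()
    b, d, n_domains = 8, 256, 2
    x_cl = torch.randn(b, d).cuda()
    x_cb = torch.randn(b, d).cuda()
    labels = torch.arange(b // 2).repeat_interleave(2).cuda()
    dom_cl = torch.randn(b, n_domains).cuda()              # contactless domain logits
    dom_cb = torch.randn(b, n_domains).cuda()              # contact-based domain logits
    dom_cl_gt = torch.zeros(b, dtype=torch.long).cuda()    # e.g. domain 0 = contactless
    dom_cb_gt = torch.ones(b, dtype=torch.long).cuda()     # e.g. domain 1 = contact-based
    return loss_fn(x_cl, x_cb, None, None, labels, x_cl.device,
                   dom_cl, dom_cb, dom_cl_gt, dom_cb_gt)

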
class DualMSLoss_FineGrained_domain_agnostic_ft(nn.Module):
    """
    Fine-tuning variant of the domain-agnostic loss: it keeps the domain-classifier
    arguments for interface compatibility but optimises only the dual
    multi-similarity terms.
    """
    def __init__(self, margin=0, max_violation=False):
        super(DualMSLoss_FineGrained_domain_agnostic_ft, self).__init__()
        self.margin = margin
        self.max_violation = max_violation
        self.thresh = 0.5
        # The constructor argument is overridden by the fixed MS margin below.
        self.margin = 0.7
        self.scale_pos = 2
        self.scale_neg = 40.0
        self.criterion = nn.CrossEntropyLoss()

    def ms_sample(self, sim_mat, label):
        """Multi-similarity loss with hard-pair mining over a similarity matrix."""
        pos_exp = torch.exp(-self.scale_pos * (sim_mat - self.thresh))
        neg_exp = torch.exp(self.scale_neg * (sim_mat - self.thresh))
        # Pairs that share a label are positives; all remaining pairs are negatives.
        pos_mask = torch.eq(label.view(-1, 1) - label.view(1, -1), 0.0).float().to(sim_mat.device)
        neg_mask = 1 - pos_mask
        P_sim = torch.where(pos_mask == 1, sim_mat, torch.ones_like(pos_exp) * 1e16)
        N_sim = torch.where(neg_mask == 1, sim_mat, torch.ones_like(neg_exp) * -1e16)
        min_P_sim, _ = torch.min(P_sim, dim=1, keepdim=True)
        max_N_sim, _ = torch.max(N_sim, dim=1, keepdim=True)
        # Mine positives below the hardest negative plus the margin, and negatives
        # above the hardest positive minus the margin.
        hard_P_sim = torch.where(P_sim - self.margin < max_N_sim, pos_exp, torch.zeros_like(pos_exp)).sum(dim=-1)
        hard_N_sim = torch.where(N_sim + self.margin > min_P_sim, neg_exp, torch.zeros_like(neg_exp)).sum(dim=-1)
        pos_loss = torch.log(1 + hard_P_sim).sum() / self.scale_pos
        neg_loss = torch.log(1 + hard_N_sim).sum() / self.scale_neg

        return pos_loss + neg_loss

    def ms_sample_cbcb_clcl(self, sim_mat, label):
        """MS loss for within-modality similarity matrices; self-similarities
        on the diagonal are excluded from both positives and negatives."""
        pos_exp = torch.exp(-self.scale_pos * (sim_mat - self.thresh))
        neg_exp = torch.exp(self.scale_neg * (sim_mat - self.thresh))
        pos_mask = torch.eq(label.view(-1, 1) - label.view(1, -1), 0.0).float().to(sim_mat.device)

        # Adding the identity bumps diagonal entries to 2, so they match neither
        # the positive (== 1) nor the negative (== 0) selections below.
        pos_mask = pos_mask + torch.eye(pos_mask.shape[0], device=sim_mat.device)
        P_sim = torch.where(pos_mask == 1, sim_mat, torch.ones_like(pos_exp) * 1e16)
        N_sim = torch.where(pos_mask == 0, sim_mat, torch.ones_like(neg_exp) * -1e16)
        min_P_sim, _ = torch.min(P_sim, dim=1, keepdim=True)
        max_N_sim, _ = torch.max(N_sim, dim=1, keepdim=True)
        hard_P_sim = torch.where(P_sim - self.margin < max_N_sim, pos_exp, torch.zeros_like(pos_exp)).sum(dim=-1)
        hard_N_sim = torch.where(N_sim + self.margin > min_P_sim, neg_exp, torch.zeros_like(neg_exp)).sum(dim=-1)
        pos_loss = torch.log(1 + hard_P_sim).sum() / self.scale_pos
        neg_loss = torch.log(1 + hard_N_sim).sum() / self.scale_neg

        return pos_loss + neg_loss

    def ms_sample_cbcb_clcl_trans(self, sim_mat, label):
        """Variant of the within-modality MS loss that drops the diagonal from
        the label mask, giving an (n, n-1) mask; sim_mat is expected in the
        same diagonal-removed layout."""
        pos_exp = torch.exp(-self.scale_pos * (sim_mat - self.thresh))
        neg_exp = torch.exp(self.scale_neg * (sim_mat - self.thresh))
        pos_mask = torch.eq(label.view(-1, 1) - label.view(1, -1), 0.0).float().to(sim_mat.device)

        # Drop the diagonal (self pairs) and repack each row to length n-1.
        n_sha = pos_mask.shape[0]
        mask_pos = ~torch.eye(n_sha, dtype=torch.bool, device=pos_mask.device)
        pos_mask = pos_mask[mask_pos].reshape(n_sha, n_sha - 1)

        neg_mask = 1 - pos_mask
        P_sim = torch.where(pos_mask == 1, sim_mat, torch.ones_like(pos_exp) * 1e16)
        N_sim = torch.where(neg_mask == 1, sim_mat, torch.ones_like(neg_exp) * -1e16)
        min_P_sim, _ = torch.min(P_sim, dim=1, keepdim=True)
        max_N_sim, _ = torch.max(N_sim, dim=1, keepdim=True)
        hard_P_sim = torch.where(P_sim - self.margin < max_N_sim, pos_exp, torch.zeros_like(pos_exp)).sum(dim=-1)
        hard_N_sim = torch.where(N_sim + self.margin > min_P_sim, neg_exp, torch.zeros_like(neg_exp)).sum(dim=-1)
        pos_loss = torch.log(1 + hard_P_sim).sum() / self.scale_pos
        neg_loss = torch.log(1 + hard_N_sim).sum() / self.scale_neg

        return pos_loss + neg_loss

    def compute_sharded_cosine_similarity(self, tensor1, tensor2, shard_size):
        """Token-level cosine similarity between two (B, T, D) tensors,
        accumulated over shards of the token dimension to limit memory use."""
        B, T, D = tensor1.shape
        average_sim_matrix = torch.zeros((B, B), device=tensor1.device)

        for start_idx in range(0, T, shard_size):
            end_idx = min(start_idx + shard_size, T)

            shard_tensor1 = tensor1[:, start_idx:end_idx, :]
            shard_tensor2 = tensor2[:, start_idx:end_idx, :]

            # (B, 1, t, 1, D) against (1, B, 1, t, D) -> pairwise sims of shape (B, B, t, t)
            shard_tensor1_expanded = shard_tensor1.unsqueeze(1).unsqueeze(3)
            shard_tensor2_expanded = shard_tensor2.unsqueeze(0).unsqueeze(2)

            shard_cos_sim = F.cosine_similarity(shard_tensor1_expanded, shard_tensor2_expanded, dim=-1)

            average_sim_matrix += torch.sum(shard_cos_sim, dim=[2, 3])

        # Only token pairs that fall in the same shard window are compared, yet the
        # sum is normalised by all T*T pairs, so cross-window pairs count as zero.
        average_sim_matrix /= (T * T)

        return average_sim_matrix

    def forward(self, x_contactless, x_contactbased, x_cl_tokens, x_cb_tokens, labels, device,
                domain_class_cl, domain_class_cb, domain_class_cl_gt, domain_class_cb_gt):
        # Same dual MS terms as the domain-agnostic loss; the domain-classification
        # arguments are accepted for interface parity but are not used here.
        sim_mat_clcl = F.linear(self.l2_norm(x_contactless), self.l2_norm(x_contactless))
        sim_mat_cbcb = F.linear(self.l2_norm(x_contactbased), self.l2_norm(x_contactbased))
        sim_mat_cbcl = F.linear(self.l2_norm(x_contactbased), self.l2_norm(x_contactless))

        loss2 = self.ms_sample_cbcb_clcl(sim_mat_clcl, labels) + self.ms_sample_cbcb_clcl(sim_mat_clcl.t(), labels)
        loss3 = self.ms_sample_cbcb_clcl(sim_mat_cbcb, labels) + self.ms_sample_cbcb_clcl(sim_mat_cbcb.t(), labels)
        loss4 = self.ms_sample(sim_mat_cbcl, labels) + self.ms_sample(sim_mat_cbcl.t(), labels)

        return loss4 + loss2 + loss3

    def l2_norm(self, input):
        """Row-wise L2 normalisation of a 2-D batch of embeddings."""
        input_size = input.size()
        buffer = torch.pow(input, 2)
        normp = torch.sum(buffer, 1).add_(1e-12)
        norm = torch.sqrt(normp)
        _output = torch.div(input, norm.view(-1, 1).expand_as(input))
        output = _output.view(input_size)

        return output


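# Optional smoke test, not part of the original module: runs the illustrative
# helpers above when this file is executed directly on a CUDA machine. Guarded
# so that importing the module has no side effects.
if __name__ == "__main__" and torch.cuda.is_available():
    print("DualMSLoss_FineGrained:", _example_dual_ms_loss().item())
    print("DualMSLoss_FineGrained_domain_agnostic:", _example_domain_agnostic_loss().item())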