# Ridgeformer / utils.py
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm


class RetMetric(object):
    """Recall@k from a precomputed query-to-gallery similarity matrix."""

    def __init__(self, sim_mat, labels):
        # labels is a (gallery_labels, query_labels) pair;
        # sim_mat has shape (num_queries, num_gallery).
        self.gallery_labels, self.query_labels = labels
        self.sim_mat = sim_mat
        self.is_equal_query = False

    def recall_k(self, k=1):
        m = len(self.sim_mat)
        match_counter = 0
        for i in range(m):
            # Similarities of query i to same-label and different-label gallery items.
            pos_sim = self.sim_mat[i][self.gallery_labels == self.query_labels[i]]
            neg_sim = self.sim_mat[i][self.gallery_labels != self.query_labels[i]]
            # When query and gallery are the same set, use the second-highest positive
            # similarity so the trivial self-match does not set the threshold.
            thresh = np.sort(pos_sim)[-2] if self.is_equal_query and len(pos_sim) > 1 else np.max(pos_sim)
            # Query i is a hit if fewer than k negatives outrank its best positive.
            if np.sum(neg_sim > thresh) < k:
                match_counter += 1
        return float(match_counter) / m
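# Usage sketch (illustrative, not from the original repo; names such as
# query_feats / gallery_feats are assumptions): given L2-normalised features,
#
#   sim_mat = np.matmul(query_feats, gallery_feats.T)
#   metric = RetMetric(sim_mat, (gallery_labels, query_labels))
#   recall1 = metric.recall_k(k=1)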


class Prev_RetMetric(object):
    """Recall@k computed from raw features rather than a precomputed similarity matrix."""

    def __init__(self, feats, labels, cl2cl=True):
        if isinstance(feats, list) and len(feats) == 2:
            # Separate gallery and query sets:
            #   feats  = [gallery_feats, query_feats]
            #   labels = [gallery_labels, query_labels]
            self.is_equal_query = False
            self.gallery_feats, self.query_feats = feats
            self.gallery_labels, self.query_labels = labels
        else:
            # A single set is matched against itself.
            self.is_equal_query = True
            self.gallery_feats = self.query_feats = feats
            self.gallery_labels = self.query_labels = labels
        self.sim_mat = np.matmul(self.query_feats, np.transpose(self.gallery_feats))
        if cl2cl:
            # Zero out the diagonal so a sample cannot be retrieved by itself
            # (only meaningful when query and gallery are the same set).
            self.sim_mat = self.sim_mat * (1 - np.identity(self.sim_mat.shape[0]))

    def recall_k(self, k=1):
        m = len(self.sim_mat)
        match_counter = 0
        for i in range(m):
            pos_sim = self.sim_mat[i][self.gallery_labels == self.query_labels[i]]
            neg_sim = self.sim_mat[i][self.gallery_labels != self.query_labels[i]]
            # With query == gallery, take the second-highest positive similarity
            # so the self-match does not set the threshold.
            thresh = np.sort(pos_sim)[-2] if self.is_equal_query else np.max(pos_sim)
            if np.sum(neg_sim > thresh) < k:
                match_counter += 1
        return float(match_counter) / m
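# Usage sketch (illustrative assumption, not from the original repo): either pass one
# feature matrix to compare against itself (cl2cl=True masks self-matches), or a
# [gallery, query] pair of NumPy arrays with matching label arrays.
#
#   self_metric = Prev_RetMetric(feats, labels, cl2cl=True)
#   cross_metric = Prev_RetMetric([gallery_feats, query_feats],
#                                 [gallery_labels, query_labels], cl2cl=False)
#   print(self_metric.recall_k(1), cross_metric.recall_k(1))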


def compute_recall_at_k(similarity_matrix, p_labels, g_labels, k):
    """Recall@k for a (num_probes, num_gallery) torch similarity matrix."""
    num_probes = p_labels.size(0)
    recall_at_k = 0.0
    for i in range(num_probes):
        probe_label = p_labels[i]
        sim_scores = similarity_matrix[i]
        # Rank gallery items by similarity and check the top-k for a label match.
        sorted_indices = torch.argsort(sim_scores, descending=True)
        top_k_indices = sorted_indices[:k]
        correct_in_top_k = any(g_labels[idx] == probe_label for idx in top_k_indices)
        recall_at_k += correct_in_top_k
    recall_at_k /= num_probes
    return recall_at_k
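# Usage sketch (illustrative; tensor names are assumptions): probe and gallery
# features as torch tensors, labels as 1-D tensors.
#
#   sim = probe_feats @ gallery_feats.t()        # (num_probes, num_gallery)
#   r5 = compute_recall_at_k(sim, p_labels, g_labels, k=5)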


def count_parameters(model):
    # Number of trainable parameters in a torch.nn.Module.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
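# Usage sketch (illustrative):
#   print(f"trainable parameters: {count_parameters(model):,}")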


def l2_norm(input):
    # Row-wise L2 normalisation: divide each row by its Euclidean norm,
    # with a small epsilon added for numerical stability.
    input_size = input.size()
    buffer = torch.pow(input, 2)
    normp = torch.sum(buffer, 1).add_(1e-12)
    norm = torch.sqrt(normp)
    _output = torch.div(input, norm.view(-1, 1).expand_as(input))
    output = _output.view(input_size)
    return output
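# Note (an observation, not from the original repo): for a 2-D input this matches
# torch.nn.functional.normalize up to the epsilon, e.g.
#
#   x = torch.randn(4, 128)
#   assert torch.allclose(l2_norm(x), F.normalize(x, p=2, dim=1), atol=1e-6)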


def compute_sharded_cosine_similarity(tensor1, tensor2, shard_size):
    """Average token-to-token cosine similarity between every pair of samples,
    computed in shards to bound peak memory."""
    B, T, D = tensor1.shape
    average_sim_matrix = torch.zeros((B, B), device=tensor1.device)
    for start_idx1 in tqdm(range(0, B, shard_size)):
        end_idx1 = min(start_idx1 + shard_size, B)
        for start_idx2 in range(0, B, shard_size):
            end_idx2 = min(start_idx2 + shard_size, B)
            # Slice out the current pair of shards.
            shard_tensor1 = tensor1[start_idx1:end_idx1]
            shard_tensor2 = tensor2[start_idx2:end_idx2]
            # Broadcast to (s1, s2, T, T, D) so every token of each sample in
            # shard 1 is compared with every token of each sample in shard 2.
            shard_tensor1_expanded = shard_tensor1.unsqueeze(1).unsqueeze(3)
            shard_tensor2_expanded = shard_tensor2.unsqueeze(0).unsqueeze(2)
            # Cosine similarity over the feature dimension.
            shard_cos_sim = F.cosine_similarity(shard_tensor1_expanded, shard_tensor2_expanded, dim=-1)
            # Accumulate the sum of the T*T token similarities per sample pair.
            average_sim_matrix[start_idx1:end_idx1, start_idx2:end_idx2] += torch.sum(shard_cos_sim, dim=[2, 3])
    # Normalize by the total number of token pairs (T*T).
    average_sim_matrix /= (T * T)
    return average_sim_matrix
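# Usage sketch (shapes are illustrative assumptions): tensor1 / tensor2 hold
# per-sample token embeddings of shape (B, T, D); shard_size trades peak memory
# for extra loop iterations.
#
#   t1 = torch.randn(256, 16, 768)
#   t2 = torch.randn(256, 16, 768)
#   sim = compute_sharded_cosine_similarity(t1, t2, shard_size=32)   # (256, 256)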