import torch
from torch.nn.functional import cross_entropy, softmax

from .configuration_bacformer import SPECIAL_TOKENS_DICT


def compute_contrastive_loss(
    protein_embeddings: torch.Tensor,
    last_hidden_state: torch.Tensor,
    special_tokens_mask: torch.Tensor,
) -> torch.Tensor:
    """Compute contrastive loss between input protein embeddings and the model outputs at protein and masked positions."""
    # Currently only supports a batch size of 1.
    assert protein_embeddings.shape[0] == last_hidden_state.shape[0] == 1

    # Keep only the positions holding a protein embedding or a masked protein.
    special_tokens_mask = special_tokens_mask.squeeze(0)
    mask = (special_tokens_mask == SPECIAL_TOKENS_DICT["PROT_EMB"]) | (
        special_tokens_mask == SPECIAL_TOKENS_DICT["MASK"]
    )
    protein_embeddings = protein_embeddings.squeeze(0)[mask]
    last_hidden_state = last_hidden_state.squeeze(0)[mask]

    # L2-normalise both sets of embeddings so the dot products below are cosine similarities.
    last_hidden_state = last_hidden_state / last_hidden_state.norm(dim=1, keepdim=True)
    protein_embeddings = protein_embeddings / protein_embeddings.norm(dim=1, keepdim=True)

    # Pairwise similarity matrix; matching pairs lie on the diagonal.
    similarity_matrix = torch.matmul(last_hidden_state, protein_embeddings.T)

    # Each output representation should be most similar to its own input protein embedding.
    n_prots = protein_embeddings.shape[0]
    labels = torch.arange(n_prots).to(protein_embeddings.device)

    loss = cross_entropy(similarity_matrix, labels)
    return loss


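# Hypothetical usage sketch: builds a toy single-genome batch and computes the contrastive loss.
# The tensor shapes are illustrative, and it assumes the SPECIAL_TOKENS_DICT values are integer
# token-type ids (as used by the equality checks above).
def _example_compute_contrastive_loss() -> torch.Tensor:
    hidden_dim = 8
    # One sequence of 5 positions: 4 protein embeddings followed by 1 masked protein.
    special_tokens_mask = torch.tensor(
        [[SPECIAL_TOKENS_DICT["PROT_EMB"]] * 4 + [SPECIAL_TOKENS_DICT["MASK"]]]
    )
    protein_embeddings = torch.randn(1, 5, hidden_dim)
    last_hidden_state = torch.randn(1, 5, hidden_dim)
    return compute_contrastive_loss(protein_embeddings, last_hidden_state, special_tokens_mask)

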
def top_k_filtering(logits: torch.Tensor, top_k: int = 50):
    """
    Keep only top_k logits and set the rest to -inf.

    Args:
        logits (torch.Tensor): Logits of shape (batch_size, vocab_size).
        top_k (int): The number of highest probability logits to keep.

    Returns
    -------
    torch.Tensor: Filtered logits where only the top k values remain, and all others are -inf.
    """
    if top_k <= 0:
        return logits

    top_k = min(top_k, logits.size(-1))
    vals, _ = torch.topk(logits, top_k, dim=-1)

    # Smallest logit that survives the filter, per batch element.
    min_vals = vals[:, -1].unsqueeze(-1)

    # Mask out (in place) everything strictly below the k-th largest logit.
    mask = logits < min_vals
    logits[mask] = float("-inf")
    return logits


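# Hypothetical usage sketch: keep the 2 largest logits of a toy row and renormalise with softmax.
# The logit values are arbitrary and chosen only to make the effect easy to see.
def _example_top_k_filtering() -> torch.Tensor:
    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
    filtered = top_k_filtering(logits.clone(), top_k=2)  # clone: filtering modifies logits in place
    return softmax(filtered, dim=-1)  # only the two kept tokens receive non-zero probability

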
def top_p_filtering(logits: torch.Tensor, top_p: float = 0.9):
    """
    Keep the smallest set of logits whose cumulative probability >= top_p.

    Args:
        logits (torch.Tensor): Logits of shape (batch_size, vocab_size).
        top_p (float): Cumulative probability threshold.

    Returns
    -------
    torch.Tensor: Filtered logits where only tokens within the top_p cumulative
        probability mass are kept; the rest are set to -inf.
    """
    if top_p >= 1.0:
        return logits

    # Sort logits so probability mass can be accumulated from most to least likely.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(softmax(sorted_logits, dim=-1), dim=-1)

    # Mark tokens whose cumulative probability exceeds the threshold.
    sorted_indices_to_remove = cumulative_probs > top_p

    # Shift the mask right by one so the first token crossing the threshold is kept,
    # and always keep the single most likely token.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    # Scatter the removal mask back to the original (unsorted) vocabulary order.
    for i in range(logits.size(0)):
        remove_indices = sorted_indices[i, sorted_indices_to_remove[i]]
        logits[i, remove_indices] = float("-inf")

    return logits


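# Hypothetical usage sketch: nucleus filtering followed by sampling, mirroring how
# top_p_filtering would typically be used inside a generation loop. Values are arbitrary.
def _example_top_p_filtering() -> torch.Tensor:
    logits = torch.tensor([[3.0, 2.0, 1.0, -2.0]])
    filtered = top_p_filtering(logits.clone(), top_p=0.9)  # clone: filtering modifies logits in place
    probs = softmax(filtered, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # sample a token id from the kept probability mass

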
def create_4d_from_2d_attn_mask(attn_mask: torch.Tensor, num_attn_heads: int):
    """Helper function to reshape attn_mask to 4D from 2D."""
    assert (
        len(attn_mask.shape) == 2
    ), f"Please provide attn_mask of shape (batch_size, seq_len), current shape {attn_mask.shape}"

    # Broadcast the (batch_size, seq_len) mask across attention heads -> (batch_size, num_attn_heads, 1, seq_len).
    bs, seq_len = attn_mask.shape
    attn_mask = attn_mask.view(bs, 1, 1, seq_len)
    attn_mask = attn_mask.expand(-1, num_attn_heads, -1, -1)
    attn_mask = attn_mask.view(bs, num_attn_heads, -1, seq_len)
    return attn_mask


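# Hypothetical usage sketch: expand a padding mask for a model with 4 attention heads.
# The batch size and sequence length are arbitrary.
def _example_create_4d_attn_mask() -> torch.Tensor:
    attn_mask = torch.ones(2, 7, dtype=torch.bool)  # (batch_size, seq_len), all positions attended
    attn_mask[1, 5:] = False  # pad out the tail of the second sequence
    return create_4d_from_2d_attn_mask(attn_mask, num_attn_heads=4)  # -> shape (2, 4, 1, 7)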