# bacformer-masked-MAG / utils_bacformer.py
import torch
from torch.nn.functional import cross_entropy, softmax
from .configuration_bacformer import SPECIAL_TOKENS_DICT


def compute_contrastive_loss(
protein_embeddings: torch.Tensor,
last_hidden_state: torch.Tensor,
special_tokens_mask: torch.Tensor,
) -> torch.Tensor:
"""Compute contrastive loss between protein embeddings and masked items."""
    # the model currently only supports batch size 1
    assert protein_embeddings.shape[0] == last_hidden_state.shape[0] == 1
# subset to mask and protein embedding tokens
special_tokens_mask = special_tokens_mask.squeeze(0)
mask = (special_tokens_mask == SPECIAL_TOKENS_DICT["PROT_EMB"]) | (
special_tokens_mask == SPECIAL_TOKENS_DICT["MASK"]
)
protein_embeddings = protein_embeddings.squeeze(0)[mask]
last_hidden_state = last_hidden_state.squeeze(0)[mask]
# Normalize embeddings
last_hidden_state = last_hidden_state / last_hidden_state.norm(dim=1, keepdim=True)
protein_embeddings = protein_embeddings / protein_embeddings.norm(dim=1, keepdim=True)
    # compute similarity matrix between contextual hidden states and input protein embeddings
similarity_matrix = torch.matmul(last_hidden_state, protein_embeddings.T)
n_prots = protein_embeddings.shape[0]
labels = torch.arange(n_prots).to(protein_embeddings.device)
# Compute the loss
loss = cross_entropy(similarity_matrix, labels)
return loss
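

# Minimal usage sketch (hypothetical shapes; not part of the original module). Every
# position flagged as PROT_EMB or MASK contributes one positive pair (its contextual
# hidden state vs. its own input protein embedding); the remaining kept positions in
# the genome act as in-batch negatives.
def _example_contrastive_loss() -> torch.Tensor:
    batch_size, seq_len, dim = 1, 6, 8
    protein_embeddings = torch.randn(batch_size, seq_len, dim)
    last_hidden_state = torch.randn(batch_size, seq_len, dim)
    # mark every position as a protein-embedding token so all of them enter the loss
    special_tokens_mask = torch.full((batch_size, seq_len), SPECIAL_TOKENS_DICT["PROT_EMB"])
    return compute_contrastive_loss(protein_embeddings, last_hidden_state, special_tokens_mask)

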
def top_k_filtering(logits: torch.Tensor, top_k: int = 50):
"""
Keep only top_k logits and set the rest to -inf.
Args:
logits (torch.Tensor): Logits of shape (batch_size, vocab_size).
top_k (int): The number of highest probability logits to keep.
Returns
-------
torch.Tensor: Filtered logits where only the top k values remain, and all others are -inf.
"""
if top_k <= 0:
return logits
# Find top_k values
top_k = min(top_k, logits.size(-1))
vals, idx = torch.topk(logits, top_k, dim=-1)
# Get the smallest logit in the top_k
min_vals = vals[:, -1].unsqueeze(-1)
# Mask all logits that are < this min value
mask = logits < min_vals
logits[mask] = float("-inf")
return logits
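

# Minimal usage sketch (hypothetical logits; not part of the original module): keep the
# three largest logits per row before sampling. The input is cloned because the filter
# writes -inf into the tensor it is given.
def _example_top_k_sampling() -> torch.Tensor:
    logits = torch.tensor([[2.0, 0.5, -1.0, 3.0, 0.0]])
    filtered = top_k_filtering(logits.clone(), top_k=3)
    probs = softmax(filtered, dim=-1)
    return torch.multinomial(probs, num_samples=1)

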
def top_p_filtering(logits: torch.Tensor, top_p: float = 0.9):
"""
Keep the smallest set of logits whose cumulative probability >= top_p.
Args:
logits (torch.Tensor): Logits of shape (batch_size, vocab_size).
top_p (float): Cumulative probability threshold.
Returns
-------
torch.Tensor: Filtered logits where only tokens within the top_p cumulative
probability mass are kept; the rest are set to -inf.
"""
if top_p >= 1.0:
return logits
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(softmax(sorted_logits, dim=-1), dim=-1)
# Identify where cumulative probability exceeds top_p
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the mask to ensure we always keep at least one token
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = False
# Scatter to replicate the mask in the original ordering
for i in range(logits.size(0)):
remove_indices = sorted_indices[i, sorted_indices_to_remove[i]]
logits[i, remove_indices] = float("-inf")
return logits
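

# Minimal usage sketch (hypothetical logits; not part of the original module): nucleus
# sampling keeps the smallest set of tokens whose cumulative probability reaches top_p.
# As with top-k, the input is cloned because the filter modifies it in place.
def _example_top_p_sampling() -> torch.Tensor:
    logits = torch.tensor([[4.0, 3.0, 1.0, 0.5, -2.0]])
    filtered = top_p_filtering(logits.clone(), top_p=0.9)
    probs = softmax(filtered, dim=-1)
    return torch.multinomial(probs, num_samples=1)

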
def create_4d_from_2d_attn_mask(attn_mask: torch.Tensor, num_attn_heads: int) -> torch.Tensor:
    """Helper function to reshape attn_mask from 2D to 4D."""
assert (
len(attn_mask.shape) == 2
), f"Please provide attn_mask of shape (batch_size, seq_len), current shape {attn_mask.shape}"
bs, seq_len = attn_mask.shape
attn_mask = attn_mask.view(bs, 1, 1, seq_len)
attn_mask = attn_mask.expand(-1, num_attn_heads, -1, -1)
attn_mask = attn_mask.view(bs, num_attn_heads, -1, seq_len)
return attn_mask
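

# Minimal usage sketch (hypothetical sizes; not part of the original module): a
# (batch_size, seq_len) padding mask is expanded to (batch_size, num_heads, 1, seq_len),
# a shape that broadcasts over the query dimension when masking attention scores.
def _example_attn_mask() -> torch.Tensor:
    attn_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # 1 = attend, 0 = padding
    mask_4d = create_4d_from_2d_attn_mask(attn_mask, num_attn_heads=4)
    assert mask_4d.shape == (2, 4, 1, 4)
    return mask_4d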