File size: 4,006 Bytes
5969223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
from torch.nn.functional import cross_entropy, softmax

from .configuration_bacformer import SPECIAL_TOKENS_DICT


def compute_contrastive_loss(
    protein_embeddings: torch.Tensor,
    last_hidden_state: torch.Tensor,
    special_tokens_mask: torch.Tensor,
) -> torch.Tensor:
    """
    Compute contrastive loss between protein embeddings and masked items.

    Only positions whose special token is ``PROT_EMB`` or ``MASK`` are kept. Both sets of
    vectors are L2-normalised, and the cosine-similarity matrix between the kept hidden
    states (rows) and the kept protein embeddings (columns) is scored with cross-entropy.
    The target for row ``i`` is column ``i``, i.e. the protein embedding at the same position.

    Args:
        protein_embeddings (torch.Tensor): Protein embeddings; presumably of shape
            (1, seq_len, hidden_dim) — TODO confirm against callers.
        last_hidden_state (torch.Tensor): Model outputs with the same layout as
            ``protein_embeddings`` (assumed).
        special_tokens_mask (torch.Tensor): Per-position special-token ids, assumed
            (1, seq_len); it is squeezed and used as a boolean index over positions.

    Returns
    -------
        torch.Tensor: Scalar cross-entropy loss.

    Raises
    ------
        ValueError: If the batch size of the inputs is not 1.
    """
    # The model currently does not work with batch size > 1. Raise rather than assert,
    # so the check survives `python -O`.
    if not (protein_embeddings.shape[0] == last_hidden_state.shape[0] == 1):
        raise ValueError(
            "compute_contrastive_loss only supports batch size 1, got "
            f"{protein_embeddings.shape[0]} and {last_hidden_state.shape[0]}"
        )

    # Subset to mask and protein-embedding tokens.
    special_tokens_mask = special_tokens_mask.squeeze(0)
    mask = (special_tokens_mask == SPECIAL_TOKENS_DICT["PROT_EMB"]) | (
        special_tokens_mask == SPECIAL_TOKENS_DICT["MASK"]
    )
    protein_embeddings = protein_embeddings.squeeze(0)[mask]
    last_hidden_state = last_hidden_state.squeeze(0)[mask]

    # L2-normalise each row. Clamping the norm avoids 0/0 = NaN for all-zero vectors and
    # leaves every non-zero norm unchanged.
    eps = 1e-12
    last_hidden_state = last_hidden_state / last_hidden_state.norm(dim=1, keepdim=True).clamp_min(eps)
    protein_embeddings = protein_embeddings / protein_embeddings.norm(dim=1, keepdim=True).clamp_min(eps)

    # (n, n) cosine similarities: row i = hidden state i, column j = protein embedding j.
    similarity_matrix = torch.matmul(last_hidden_state, protein_embeddings.T)

    # The positive pair for each row lies on the diagonal.
    labels = torch.arange(protein_embeddings.shape[0], device=protein_embeddings.device)

    return cross_entropy(similarity_matrix, labels)


def top_k_filtering(logits: torch.Tensor, top_k: int = 50):
    """
    Keep only the top_k largest logits along the last dimension and set the rest to -inf.

    ``logits`` is modified in place and also returned. Entries that tie with the k-th
    largest value are kept, so more than ``top_k`` entries may survive.

    Args:
        logits (torch.Tensor): Logits of shape (..., vocab_size), e.g. (batch_size, vocab_size)
            or (vocab_size,).
        top_k (int): The number of highest logits to keep. Values <= 0 disable filtering.

    Returns
    -------
        torch.Tensor: ``logits`` with every entry below the k-th largest value set to -inf.
    """
    if top_k <= 0:
        return logits

    # Clamp k so torch.topk does not fail when top_k exceeds the vocabulary size.
    k = min(top_k, logits.size(-1))
    # Smallest kept value per row. Slicing with `-1:` keeps a trailing singleton dim, so it
    # broadcasts against logits of any rank (the old `[:, -1]` only worked for 2D input).
    kth_largest = torch.topk(logits, k, dim=-1).values[..., -1:]
    logits.masked_fill_(logits < kth_largest, float("-inf"))
    return logits


def top_p_filtering(logits: torch.Tensor, top_p: float = 0.9):
    """
    Keep the smallest set of logits whose cumulative probability >= top_p.

    Filtering is applied independently along the last dimension. ``logits`` is modified
    in place and also returned. The single most likely token is always kept.

    Args:
        logits (torch.Tensor): Logits of shape (..., vocab_size), e.g. (batch_size, vocab_size)
            or (vocab_size,).
        top_p (float): Cumulative probability threshold. Values >= 1.0 disable filtering.

    Returns
    -------
        torch.Tensor: Filtered logits where only tokens within the top_p cumulative
                      probability mass are kept; the rest are set to -inf.
    """
    if top_p >= 1.0:
        return logits

    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(softmax(sorted_logits, dim=-1), dim=-1)

    # Identify where cumulative probability exceeds top_p (in sorted order).
    sorted_to_remove = cumulative_probs > top_p
    # Shift right by one so the token that crosses the threshold is kept, and always keep
    # at least the most likely token.
    sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
    sorted_to_remove[..., 0] = False

    # Map the mask back to the original vocabulary order in one vectorised op.
    # sorted_indices is a permutation along the last dim, so every position is written.
    # This replaces the previous Python loop over rows, which also only handled 2D input.
    to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
    logits.masked_fill_(to_remove, float("-inf"))
    return logits


def create_4d_from_2d_attn_mask(attn_mask: torch.Tensor, num_attn_heads: int):
    """
    Broadcast a 2D attention mask to a 4D per-head mask.

    Args:
        attn_mask (torch.Tensor): Mask of shape (batch_size, seq_len).
        num_attn_heads (int): Number of attention heads to broadcast the mask over.

    Returns
    -------
        torch.Tensor: Mask of shape (batch_size, num_attn_heads, 1, seq_len). This is an
        ``expand``-ed view of ``attn_mask`` (no copy). It shares memory with the input and
        should not be written to in place.

    Raises
    ------
        ValueError: If ``attn_mask`` is not 2D.
    """
    # Raise instead of assert so the check survives `python -O`.
    if attn_mask.dim() != 2:
        raise ValueError(
            f"Please provide attn_mask of shape (batch_size, seq_len), current shape {attn_mask.shape}"
        )

    bs, seq_len = attn_mask.shape
    # (bs, seq_len) -> (bs, 1, 1, seq_len) -> (bs, num_attn_heads, 1, seq_len).
    # The former trailing `.view(bs, num_attn_heads, -1, seq_len)` was a no-op. It also
    # failed on 0-element masks, where the `-1` dimension is ambiguous.
    return attn_mask.view(bs, 1, 1, seq_len).expand(-1, num_attn_heads, -1, -1)