macwiatrak committed (verified)
Commit 5969223 · 1 Parent(s): 7a26e46

Upload BacformerForCausalGM

Files changed (2)
  1. modeling_bacformer.py +5 -126
  2. utils_bacformer.py +109 -0
modeling_bacformer.py CHANGED
@@ -16,112 +16,8 @@ from torch.nn.functional import (
 from transformers import PreTrainedModel
 from transformers.utils import ModelOutput

-from bacformer_model.configuration_bacformer import SPECIAL_TOKENS_DICT, BacformerConfig
-
-
-def compute_contrastive_loss(
-    protein_embeddings: torch.Tensor,
-    last_hidden_state: torch.Tensor,
-    special_tokens_mask: torch.Tensor,
-) -> torch.Tensor:
-    """Compute contrastive loss between protein embeddings and masked items."""
-    # keep protein embeddings and masked items
-    # ensure the batch size is 1, the model currently does not work with batch size > 1
-    assert protein_embeddings.shape[0] == last_hidden_state.shape[0] == 1
-
-    # subset to mask and protein embedding tokens
-    special_tokens_mask = special_tokens_mask.squeeze(0)
-    mask = (special_tokens_mask == SPECIAL_TOKENS_DICT["PROT_EMB"]) | (
-        special_tokens_mask == SPECIAL_TOKENS_DICT["MASK"]
-    )
-    protein_embeddings = protein_embeddings.squeeze(0)[mask]
-    last_hidden_state = last_hidden_state.squeeze(0)[mask]
-
-    # Normalize embeddings
-    last_hidden_state = last_hidden_state / last_hidden_state.norm(dim=1, keepdim=True)
-    protein_embeddings = protein_embeddings / protein_embeddings.norm(dim=1, keepdim=True)
-
-    # Compute similarity matrix and loss as before
-    similarity_matrix = torch.matmul(last_hidden_state, protein_embeddings.T)
-
-    n_prots = protein_embeddings.shape[0]
-    labels = torch.arange(n_prots).to(protein_embeddings.device)
-
-    # Compute the loss
-    loss = cross_entropy(similarity_matrix, labels)
-    return loss
-
-
-def top_k_filtering(logits: torch.Tensor, top_k: int = 50):
-    """
-    Keep only top_k logits and set the rest to -inf.
-
-    Args:
-        logits (torch.Tensor): Logits of shape (batch_size, vocab_size).
-        top_k (int): The number of highest probability logits to keep.
-
-    Returns
-    -------
-    torch.Tensor: Filtered logits where only the top k values remain, and all others are -inf.
-    """
-    if top_k <= 0:
-        return logits
-
-    # Find top_k values
-    top_k = min(top_k, logits.size(-1))
-    vals, idx = torch.topk(logits, top_k, dim=-1)
-    # Get the smallest logit in the top_k
-    min_vals = vals[:, -1].unsqueeze(-1)
-    # Mask all logits that are < this min value
-    mask = logits < min_vals
-    logits[mask] = float("-inf")
-    return logits
-
-
-def top_p_filtering(logits: torch.Tensor, top_p: float = 0.9):
-    """
-    Keep the smallest set of logits whose cumulative probability >= top_p.
-
-    Args:
-        logits (torch.Tensor): Logits of shape (batch_size, vocab_size).
-        top_p (float): Cumulative probability threshold.
-
-    Returns
-    -------
-    torch.Tensor: Filtered logits where only tokens within the top_p cumulative
-        probability mass are kept; the rest are set to -inf.
-    """
-    if top_p >= 1.0:
-        return logits
-
-    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-    cumulative_probs = torch.cumsum(softmax(sorted_logits, dim=-1), dim=-1)
-
-    # Identify where cumulative probability exceeds top_p
-    sorted_indices_to_remove = cumulative_probs > top_p
-    # Shift the mask to ensure we always keep at least one token
-    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-    sorted_indices_to_remove[..., 0] = False
-
-    # Scatter to replicate the mask in the original ordering
-    for i in range(logits.size(0)):
-        remove_indices = sorted_indices[i, sorted_indices_to_remove[i]]
-        logits[i, remove_indices] = float("-inf")
-
-    return logits
-
-
-def create_4d_from_2d_attn_mask(attn_mask: torch.Tensor, num_attn_heads: int):
-    """Helper function to reshape attn_mask to 3D from 2D"""
-    assert (
-        len(attn_mask.shape) == 2
-    ), f"Please provide attn_mask of shape (batch_size, seq_len), current shape {attn_mask.shape}"
-
-    bs, seq_len = attn_mask.shape
-    attn_mask = attn_mask.view(bs, 1, 1, seq_len)
-    attn_mask = attn_mask.expand(-1, num_attn_heads, -1, -1)
-    attn_mask = attn_mask.view(bs, num_attn_heads, -1, seq_len)
-    return attn_mask
+from .configuration_bacformer import SPECIAL_TOKENS_DICT, BacformerConfig
+from .utils_bacformer import compute_contrastive_loss, create_4d_from_2d_attn_mask, top_k_filtering, top_p_filtering


 @dataclass
@@ -186,23 +82,6 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
     return freqs_cos, freqs_sin


-def symmetrize(x):
-    """Make layer symmetric in final two dimensions, used for protein-protein interaction prediction."""
-    return x + x.transpose(-1, -2)
-
-
-def average_product_correct(x):
-    """Perform average product correct, used for protein-protein interaction prediction."""
-    a1 = x.sum(-1, keepdims=True)
-    a2 = x.sum(-2, keepdims=True)
-    a12 = x.sum((-1, -2), keepdims=True)
-
-    avg = a1 * a2
-    # avg.div_(a12) # in-place to reduce memory
-    normalized = x - avg.div_(a12)
-    return normalized
-
-
 def scaled_dot_product_attention_w_attn_weights(
     query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
 ) -> tuple[torch.Tensor, torch.Tensor]:
@@ -416,7 +295,7 @@ class BacformerTransformerEncoder(nn.Module):
 class BacformerEmbeddings(nn.Module):
     """Construct the protein embeddings from protein sequence, position embeddings and sequence type embeddings."""

-    def __init__(self, config: BacformerConfig):
+    def __init__(self, config):
         super().__init__()
         self.config = config
         self.linear = nn.Linear(config.hidden_size, config.hidden_size)
@@ -469,7 +348,7 @@ class BacformerProteinFamilyEmbeddings(nn.Module):

     def __init__(
         self,
-        config: BacformerConfig,
+        config,
         protein_family_embeddings: torch.Tensor = None,
         token_type_embeddings: torch.Tensor = None,
         special_tokens_embeddings: torch.Tensor = None,
@@ -573,7 +452,7 @@ class BacformerProteinFamilyEmbeddings(nn.Module):
 class BacformerEncoder(nn.Module):
     """Bacformer encoder model"""

-    def __init__(self, config: BacformerConfig):
+    def __init__(self, config):
         super().__init__()
         self.config = config

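The helpers removed above are re-added in the new `utils_bacformer.py` below. As a quick illustration of one of them, `create_4d_from_2d_attn_mask` broadcasts a 2D padding mask of shape `(batch_size, seq_len)` into a per-head 4D mask of shape `(batch_size, num_attn_heads, 1, seq_len)` that can broadcast against per-head attention scores. The sketch below is not part of the commit: the toy mask values and the head count are made up, and the helper is condensed (docstring and assert dropped) so the demo is standalone.

```python
import torch


def create_4d_from_2d_attn_mask(attn_mask: torch.Tensor, num_attn_heads: int):
    """Condensed copy of the helper shown in the diff above, for a standalone demo."""
    bs, seq_len = attn_mask.shape
    attn_mask = attn_mask.view(bs, 1, 1, seq_len)
    attn_mask = attn_mask.expand(-1, num_attn_heads, -1, -1)
    return attn_mask.view(bs, num_attn_heads, -1, seq_len)


# Toy padding mask: 2 genomes, 5 protein positions each,
# with the last 2 positions of the second genome padded out.
mask_2d = torch.tensor([[1, 1, 1, 1, 1],
                        [1, 1, 1, 0, 0]])
mask_4d = create_4d_from_2d_attn_mask(mask_2d, num_attn_heads=4)
print(mask_4d.shape)  # torch.Size([2, 4, 1, 5])
# Each head gets the same (1, seq_len) row, which broadcasts against
# attention scores of shape (batch_size, num_heads, seq_len, seq_len).
```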
utils_bacformer.py ADDED
@@ -0,0 +1,109 @@
+import torch
+from torch.nn.functional import cross_entropy, softmax
+
+from .configuration_bacformer import SPECIAL_TOKENS_DICT
+
+
+def compute_contrastive_loss(
+    protein_embeddings: torch.Tensor,
+    last_hidden_state: torch.Tensor,
+    special_tokens_mask: torch.Tensor,
+) -> torch.Tensor:
+    """Compute contrastive loss between protein embeddings and masked items."""
+    # keep protein embeddings and masked items
+    # ensure the batch size is 1, the model currently does not work with batch size > 1
+    assert protein_embeddings.shape[0] == last_hidden_state.shape[0] == 1
+
+    # subset to mask and protein embedding tokens
+    special_tokens_mask = special_tokens_mask.squeeze(0)
+    mask = (special_tokens_mask == SPECIAL_TOKENS_DICT["PROT_EMB"]) | (
+        special_tokens_mask == SPECIAL_TOKENS_DICT["MASK"]
+    )
+    protein_embeddings = protein_embeddings.squeeze(0)[mask]
+    last_hidden_state = last_hidden_state.squeeze(0)[mask]
+
+    # Normalize embeddings
+    last_hidden_state = last_hidden_state / last_hidden_state.norm(dim=1, keepdim=True)
+    protein_embeddings = protein_embeddings / protein_embeddings.norm(dim=1, keepdim=True)
+
+    # Compute similarity matrix and loss as before
+    similarity_matrix = torch.matmul(last_hidden_state, protein_embeddings.T)
+
+    n_prots = protein_embeddings.shape[0]
+    labels = torch.arange(n_prots).to(protein_embeddings.device)
+
+    # Compute the loss
+    loss = cross_entropy(similarity_matrix, labels)
+    return loss
+
+
+def top_k_filtering(logits: torch.Tensor, top_k: int = 50):
+    """
+    Keep only top_k logits and set the rest to -inf.
+
+    Args:
+        logits (torch.Tensor): Logits of shape (batch_size, vocab_size).
+        top_k (int): The number of highest probability logits to keep.
+
+    Returns
+    -------
+    torch.Tensor: Filtered logits where only the top k values remain, and all others are -inf.
+    """
+    if top_k <= 0:
+        return logits
+
+    # Find top_k values
+    top_k = min(top_k, logits.size(-1))
+    vals, idx = torch.topk(logits, top_k, dim=-1)
+    # Get the smallest logit in the top_k
+    min_vals = vals[:, -1].unsqueeze(-1)
+    # Mask all logits that are < this min value
+    mask = logits < min_vals
+    logits[mask] = float("-inf")
+    return logits
+
+
+def top_p_filtering(logits: torch.Tensor, top_p: float = 0.9):
+    """
+    Keep the smallest set of logits whose cumulative probability >= top_p.
+
+    Args:
+        logits (torch.Tensor): Logits of shape (batch_size, vocab_size).
+        top_p (float): Cumulative probability threshold.
+
+    Returns
+    -------
+    torch.Tensor: Filtered logits where only tokens within the top_p cumulative
+        probability mass are kept; the rest are set to -inf.
+    """
+    if top_p >= 1.0:
+        return logits
+
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+    cumulative_probs = torch.cumsum(softmax(sorted_logits, dim=-1), dim=-1)
+
+    # Identify where cumulative probability exceeds top_p
+    sorted_indices_to_remove = cumulative_probs > top_p
+    # Shift the mask to ensure we always keep at least one token
+    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+    sorted_indices_to_remove[..., 0] = False
+
+    # Scatter to replicate the mask in the original ordering
+    for i in range(logits.size(0)):
+        remove_indices = sorted_indices[i, sorted_indices_to_remove[i]]
+        logits[i, remove_indices] = float("-inf")
+
+    return logits
+
+
+def create_4d_from_2d_attn_mask(attn_mask: torch.Tensor, num_attn_heads: int):
+    """Helper function to reshape attn_mask to 4D from 2D"""
+    assert (
+        len(attn_mask.shape) == 2
+    ), f"Please provide attn_mask of shape (batch_size, seq_len), current shape {attn_mask.shape}"
+
+    bs, seq_len = attn_mask.shape
+    attn_mask = attn_mask.view(bs, 1, 1, seq_len)
+    attn_mask = attn_mask.expand(-1, num_attn_heads, -1, -1)
+    attn_mask = attn_mask.view(bs, num_attn_heads, -1, seq_len)
+    return attn_mask
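`top_k_filtering` and `top_p_filtering` are the usual top-k / nucleus filters applied to next-token logits during autoregressive generation. The sketch below is not part of the commit: it shows one common way to combine them with a temperature and `torch.multinomial`. The random logits and the temperature/top-k/top-p values are placeholders, and the two helpers are assumed to already be in scope (for example, pasted from the file above), since importing `utils_bacformer.py` on its own would fail on its relative import.

```python
import torch
from torch.nn.functional import softmax

# Assumes top_k_filtering and top_p_filtering from utils_bacformer.py
# (shown above) are already defined in this session.

torch.manual_seed(0)
logits = torch.randn(1, 10)  # placeholder next-token logits over a toy vocabulary of 10

temperature, top_k, top_p = 0.8, 5, 0.9

filtered = logits.clone() / temperature             # clone: both filters modify logits in place
filtered = top_k_filtering(filtered, top_k=top_k)   # keep only the 5 largest logits
filtered = top_p_filtering(filtered, top_p=top_p)   # then keep the 0.9 cumulative-probability nucleus

probs = softmax(filtered, dim=-1)                   # -inf logits become probability 0
next_token = torch.multinomial(probs, num_samples=1)
print(next_token.shape)  # torch.Size([1, 1])
```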