macwiatrak committed (verified)
Commit 9c2e756 · 1 Parent(s): 9e8002a

Upload BacformerForMaskedGM

Files changed (3)
  1. config.json +2 -1
  2. modeling_bacformer.py +1340 -0
  3. utils_bacformer.py +109 -0
config.json CHANGED
@@ -5,7 +5,8 @@
   ],
   "attention_probs_dropout_prob": 0.1,
   "auto_map": {
-    "AutoConfig": "configuration_bacformer.BacformerConfig"
+    "AutoConfig": "configuration_bacformer.BacformerConfig",
+    "AutoModelForMaskedLM": "modeling_bacformer.BacformerForMaskedGM"
   },
   "batch_size": 1,
   "ckpt_path": null,
modeling_bacformer.py ADDED
@@ -0,0 +1,1340 @@
1
+ import math
2
+ from collections import OrderedDict
3
+ from dataclasses import dataclass
4
+ from typing import Literal, Optional, Union
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn.functional import (
9
+ binary_cross_entropy_with_logits,
10
+ cross_entropy,
11
+ gelu,
12
+ mse_loss,
13
+ scaled_dot_product_attention,
14
+ softmax,
15
+ )
16
+ from transformers import PreTrainedModel
17
+ from transformers.utils import ModelOutput
18
+
19
+ from .configuration_bacformer import SPECIAL_TOKENS_DICT, BacformerConfig
20
+ from .utils_bacformer import compute_contrastive_loss, create_4d_from_2d_attn_mask, top_k_filtering, top_p_filtering
21
+
22
+
23
+ @dataclass
24
+ class BacformerModelOutput(ModelOutput):
25
+ """Base class for outputs of the Bacformer model."""
26
+
27
+ loss: torch.FloatTensor | None = None
28
+ logits: torch.FloatTensor | None = None
29
+ last_hidden_state: torch.FloatTensor | None = None
30
+ attentions: torch.FloatTensor | None = None
31
+ pooler_output: torch.FloatTensor | None = None
32
+
33
+
34
+ # Taken from facebookresearch/llama/model.py
35
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
36
+ """Reshape the rotary embeddings for broadcasting."""
37
+ ndim = x.ndim
38
+ assert 0 <= 1 < ndim
39
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
40
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
41
+ return freqs_cis.view(*shape)
42
+
43
+
44
+ # Taken from facebookresearch/llama/model.py
45
+ def apply_rotary_emb(
46
+ xq: torch.Tensor,
47
+ xk: torch.Tensor,
48
+ freqs_cos: torch.Tensor,
49
+ freqs_sin: torch.Tensor,
50
+ ) -> tuple[torch.Tensor, torch.Tensor]:
51
+ """Apply rotary embeddings to the query and key tensors."""
52
+ # reshape xq and xk to match the complex representation
53
+ xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1)
54
+ xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1)
55
+
56
+ # reshape freqs_cos and freqs_sin for broadcasting
57
+ freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
58
+ freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
59
+
60
+ # apply rotation using real numbers
61
+ xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
62
+ xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
63
+ xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
64
+ xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
65
+
66
+ # flatten last two dimensions
67
+ xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
68
+ xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
69
+
70
+ return xq_out.type_as(xq), xk_out.type_as(xk)
71
+
72
+
73
+ # Taken from facebookresearch/llama/model.py
74
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
75
+ """Precompute the freqs cis for rotary embeddings."""
76
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
77
+ t = torch.arange(end, device=freqs.device) # type: ignore
78
+ freqs = torch.outer(t, freqs).float() # type: ignore
79
+
80
+ freqs_cos = torch.cos(freqs) # real part
81
+ freqs_sin = torch.sin(freqs) # imaginary part
82
+ return freqs_cos, freqs_sin
83
+
84
+
85
+ def scaled_dot_product_attention_w_attn_weights(
86
+ query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
87
+ ) -> tuple[torch.Tensor, torch.Tensor]:
88
+ """PyTorch Native implementation, modified to return attention weights."""
89
+ L, S = query.size(-2), key.size(-2)
90
+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
91
+ attn_bias = torch.zeros(L, S, dtype=query.dtype).to(query.device)
92
+ if is_causal:
93
+ assert attn_mask is None
94
+ temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
95
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
96
+ attn_bias = attn_bias.to(query.dtype)
97
+
98
+ if attn_mask is not None:
99
+ if attn_mask.dtype == torch.bool:
100
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
101
+ else:
102
+ attn_bias += attn_mask
103
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
104
+ attn_weight += attn_bias
105
+ attn_weight = torch.softmax(attn_weight, dim=-1)
106
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
107
+ attn_output = attn_weight @ value
108
+ return attn_output, attn_weight
109
+
110
+
111
+ class RotarySelfAttention(nn.Module):
112
+ """Rotary self-attention module."""
113
+
114
+ def __init__(
115
+ self,
116
+ embed_dim: int,
117
+ num_heads: int,
118
+ dropout: float = 0.1,
119
+ ):
120
+ super().__init__()
121
+ self.embed_dim = embed_dim
122
+ self.num_heads = num_heads
123
+ self.dim_head = embed_dim // num_heads
124
+ self.dropout_rate = dropout
125
+
126
+ self.q = nn.Linear(embed_dim, embed_dim, bias=False)
127
+ self.k = nn.Linear(embed_dim, embed_dim, bias=False)
128
+ self.v = nn.Linear(embed_dim, embed_dim, bias=False)
129
+ self.att_proj_linear = nn.Linear(embed_dim, embed_dim)
130
+
131
+ def forward(
132
+ self,
133
+ x: torch.Tensor,
134
+ attn_mask: torch.Tensor,
135
+ freqs_cos: torch.Tensor,
136
+ freqs_sin: torch.Tensor,
137
+ is_causal: bool = False,
138
+ return_attn_weights: bool = False,
139
+ ):
140
+ """Forward pass for the rotary self-attention module."""
141
+ batch_size, seq_len, _ = x.shape
142
+ xq, xk, xv = self.q(x), self.k(x), self.v(x)
143
+ # Reshape for rotary embeddings
144
+ xq = xq.view(batch_size, seq_len, self.num_heads, self.dim_head)
145
+ xk = xk.view(batch_size, seq_len, self.num_heads, self.dim_head)
146
+ xv = xv.view(batch_size, seq_len, self.num_heads, self.dim_head)
147
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
148
+
149
+ # Reshape for attention calculation: (b_sz, n_head, s_len, d_head)
150
+ xq = xq.transpose(1, 2)
151
+ xk = xk.transpose(1, 2)
152
+ xv = xv.transpose(1, 2)
153
+
154
+ attn_weights = None
155
+ if return_attn_weights:
156
+ att, attn_weights = scaled_dot_product_attention_w_attn_weights(
157
+ query=xq,
158
+ key=xk,
159
+ value=xv,
160
+ attn_mask=attn_mask,
161
+ dropout_p=self.dropout_rate if self.training else 0.0,
162
+ is_causal=is_causal,
163
+ )
164
+ else:
165
+ att = scaled_dot_product_attention(
166
+ query=xq,
167
+ key=xk,
168
+ value=xv,
169
+ attn_mask=attn_mask,
170
+ dropout_p=self.dropout_rate if self.training else 0.0,
171
+ is_causal=is_causal,
172
+ )
173
+ # Shape (b_sz, s_len, n_head, d_head)
174
+ out = att.transpose(1, 2).contiguous()
175
+ out = out.view(batch_size, seq_len, self.num_heads * self.dim_head)
176
+
177
+ return self.att_proj_linear(out), attn_weights
178
+
179
+
180
+ class BacformerTransformerLayer(nn.Module):
181
+ """Own implementation of transformer layer which uses pytorch native MHA but returns attention weights"""
182
+
183
+ def __init__(
184
+ self,
185
+ hidden_size: int,
186
+ intermediate_size: int,
187
+ num_attention_heads: int,
188
+ dropout: float = 0.1,
189
+ activation: Literal["gelu", "relu"] = "gelu",
190
+ ):
191
+ super().__init__()
192
+ self.self_mha = RotarySelfAttention(
193
+ embed_dim=hidden_size,
194
+ num_heads=num_attention_heads,
195
+ dropout=dropout,
196
+ )
197
+
198
+ self.fc1 = nn.Linear(hidden_size, intermediate_size)
199
+ self.fc2 = nn.Linear(intermediate_size, hidden_size)
200
+ self.activation = nn.GELU() if activation == "gelu" else nn.ReLU()
201
+ self.norm1 = nn.LayerNorm(hidden_size)
202
+ self.norm2 = nn.LayerNorm(hidden_size)
203
+ self.dropout1 = nn.Dropout(dropout)
204
+ self.dropout2 = nn.Dropout(dropout)
205
+ self.dropout3 = nn.Dropout(dropout)
206
+
207
+ def forward(
208
+ self,
209
+ hidden_state: torch.Tensor,
210
+ attention_mask: torch.Tensor = None,
211
+ freqs_cos: torch.Tensor = None,
212
+ freqs_sin: torch.Tensor = None,
213
+ return_attn_weights: bool = False,
214
+ is_causal: bool = False,
215
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
216
+ """Forward pass"""
217
+ attn_outputs, attn_weights = self.self_mha(
218
+ hidden_state,
219
+ attn_mask=attention_mask,
220
+ freqs_cos=freqs_cos,
221
+ freqs_sin=freqs_sin,
222
+ return_attn_weights=return_attn_weights,
223
+ is_causal=is_causal,
224
+ )
225
+ x = self.norm1(hidden_state + self.dropout1(attn_outputs))
226
+ ff_output = self.fc2(self.dropout2(self.activation(self.fc1(x))))
227
+ x = self.norm2(x + self.dropout3(ff_output))
228
+ return x, attn_weights
229
+
230
+
231
+ class BacformerTransformerEncoder(nn.Module):
232
+ """Own implementation of Transformer which return attention weights"""
233
+
234
+ def __init__(
235
+ self,
236
+ num_hidden_layers: int,
237
+ hidden_size: int,
238
+ intermediate_size: int,
239
+ num_attention_heads: int,
240
+ dropout: float = 0.1,
241
+ activation: Literal["gelu", "relu"] = "gelu",
242
+ ):
243
+ super().__init__()
244
+
245
+ self.layers = nn.ModuleList(
246
+ [
247
+ BacformerTransformerLayer(
248
+ hidden_size=hidden_size,
249
+ intermediate_size=intermediate_size,
250
+ num_attention_heads=num_attention_heads,
251
+ dropout=dropout,
252
+ activation=activation,
253
+ )
254
+ for _ in range(num_hidden_layers)
255
+ ]
256
+ )
257
+ self.gradient_checkpointing = False
258
+
259
+ def forward(
260
+ self,
261
+ hidden_state: torch.Tensor,
262
+ attention_mask: torch.Tensor = None,
263
+ freqs_cos: torch.Tensor = None,
264
+ freqs_sin: torch.Tensor = None,
265
+ return_attn_weights: bool = False,
266
+ is_causal: bool = False,
267
+ ) -> tuple[torch.Tensor, list[torch.Tensor | None]]:
268
+ """Forward pass"""
269
+ attn_weights_arr = []
270
+ for layer in self.layers:
271
+ if self.gradient_checkpointing and self.training:
272
+ hidden_state, attn_weights = self._gradient_checkpointing_func(
273
+ layer.__call__,
274
+ hidden_state,
275
+ attention_mask,
276
+ freqs_cos,
277
+ freqs_sin,
278
+ return_attn_weights,
279
+ is_causal,
280
+ )
281
+ else:
282
+ hidden_state, attn_weights = layer(
283
+ hidden_state=hidden_state,
284
+ attention_mask=attention_mask,
285
+ freqs_cos=freqs_cos,
286
+ freqs_sin=freqs_sin,
287
+ return_attn_weights=return_attn_weights,
288
+ is_causal=is_causal,
289
+ )
290
+ # keep the attention weights from each layer
291
+ attn_weights_arr.append(attn_weights)
292
+ return hidden_state, attn_weights_arr
293
+
294
+
295
+ class BacformerEmbeddings(nn.Module):
296
+ """Construct the protein embeddings from protein sequence, position embeddings and sequence type embeddings."""
297
+
298
+ def __init__(self, config):
299
+ super().__init__()
300
+ self.config = config
301
+ self.linear = nn.Linear(config.hidden_size, config.hidden_size)
302
+
303
+ self.token_type_embeddings = nn.Embedding(
304
+ num_embeddings=config.max_token_type_embeddings + 1,
305
+ embedding_dim=config.hidden_size,
306
+ padding_idx=config.max_token_type_embeddings,
307
+ )
308
+
309
+ self.special_tokens_embeddings = nn.Embedding(
310
+ num_embeddings=config.num_special_tokens,
311
+ embedding_dim=config.hidden_size,
312
+ )
313
+ self.prot_emb_token_id = config.prot_emb_token_id
314
+ self.pad_token_id = config.pad_token_id
315
+
316
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
317
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
318
+
319
+ def forward(
320
+ self,
321
+ protein_embeddings: torch.Tensor = None,
322
+ special_tokens_mask: torch.Tensor = None,
323
+ token_type_ids: torch.Tensor = None,
324
+ labels: torch.Tensor = None, # used for causal protein family modeling
325
+ property_ids: torch.Tensor = None, # used for conditional fine-tuning for desired property
326
+ ) -> torch.Tensor:
327
+ """Forward pass for protein embeddings."""
328
+ bs, seq_length, dim = protein_embeddings.shape
329
+
330
+ # pass the pooled ESM protein embeddings through a linear layer
331
+ protein_embeddings = self.linear(protein_embeddings)
332
+ protein_embeddings = torch.where(
333
+ special_tokens_mask.unsqueeze(-1).repeat(1, 1, dim) == self.prot_emb_token_id,
334
+ protein_embeddings,
335
+ self.special_tokens_embeddings(special_tokens_mask),
336
+ )
337
+
338
+ if token_type_ids is not None:
339
+ protein_embeddings += self.token_type_embeddings(token_type_ids)
340
+
341
+ protein_embeddings = self.LayerNorm(protein_embeddings)
342
+ protein_embeddings = self.dropout(protein_embeddings)
343
+ return protein_embeddings
344
+
345
+
346
+ class BacformerProteinFamilyEmbeddings(nn.Module):
347
+ """Construct the protein embeddings from protein family tokens, special tokens and sequence type embeddings."""
348
+
349
+ def __init__(
350
+ self,
351
+ config,
352
+ protein_family_embeddings: torch.Tensor = None,
353
+ token_type_embeddings: torch.Tensor = None,
354
+ special_tokens_embeddings: torch.Tensor = None,
355
+ n_conditional_properties: int = None,
356
+ ):
357
+ super().__init__()
358
+ self.config = config
359
+
360
+ if protein_family_embeddings is not None:
361
+ self.protein_family_embeddings = nn.Embedding.from_pretrained(
362
+ protein_family_embeddings,
363
+ freeze=False,
364
+ padding_idx=config.pad_token_id,
365
+ )
366
+ else:
367
+ self.protein_family_embeddings = nn.Embedding(
368
+ num_embeddings=config.protein_clusters_vocab_size + 1,
369
+ embedding_dim=config.hidden_size,
370
+ padding_idx=config.pad_token_id,
371
+ )
372
+
373
+ if token_type_embeddings is not None:
374
+ self.token_type_embeddings = nn.Embedding.from_pretrained(
375
+ token_type_embeddings,
376
+ freeze=False,
377
+ padding_idx=config.max_token_type_embeddings,
378
+ )
379
+ else:
380
+ self.token_type_embeddings = nn.Embedding(
381
+ num_embeddings=config.max_token_type_embeddings + 1,
382
+ embedding_dim=config.hidden_size,
383
+ padding_idx=config.max_token_type_embeddings,
384
+ )
385
+
386
+ if special_tokens_embeddings is not None:
387
+ self.special_tokens_embeddings = nn.Embedding.from_pretrained(
388
+ special_tokens_embeddings,
389
+ freeze=False,
390
+ padding_idx=config.pad_token_id,
391
+ )
392
+ else:
393
+ self.special_tokens_embeddings = nn.Embedding(
394
+ num_embeddings=config.num_special_tokens,
395
+ embedding_dim=config.hidden_size,
396
+ padding_idx=config.pad_token_id,
397
+ )
398
+
399
+ # add layer for conditional properties
400
+ if n_conditional_properties is not None:
401
+ self.conditional_properties_layer = nn.Embedding(n_conditional_properties, config.hidden_size)
402
+
403
+ self.prot_emb_token_id = config.prot_emb_token_id
404
+ self.pad_token_id = config.pad_token_id
405
+
406
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
407
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
408
+
409
+ def forward(
410
+ self,
411
+ protein_embeddings: torch.Tensor = None,
412
+ special_tokens_mask: torch.Tensor = None,
413
+ token_type_ids: torch.Tensor = None,
414
+ labels: torch.Tensor = None, # used for causal protein family modeling
415
+ property_ids: torch.Tensor = None, # used for conditional fine-tuning for desired property
416
+ ) -> torch.Tensor:
417
+ """Forward pass for protein embeddings."""
418
+ # pass the pooled ESM protein embeddings through a linear layer
419
+ # replace -100 with pad_token_id
420
+ labels[labels == -100] = self.pad_token_id
421
+ protein_embeddings = self.protein_family_embeddings(labels)
422
+
423
+ bs, seq_length, dim = protein_embeddings.shape
424
+ protein_embeddings = torch.where(
425
+ special_tokens_mask.unsqueeze(-1).repeat(1, 1, dim) == self.prot_emb_token_id,
426
+ protein_embeddings,
427
+ self.special_tokens_embeddings(special_tokens_mask),
428
+ )
429
+
430
+ if token_type_ids is not None:
431
+ protein_embeddings += self.token_type_embeddings(token_type_ids)
432
+
433
+ if property_ids is not None:
434
+ # get the embeddings for the conditional properties
435
+ property_embedding = self.conditional_properties_layer(property_ids).unsqueeze(1)
436
+ # concatenate the protein embeddings with the conditional properties embeddings
437
+ # property embeddings are added to the beginning of the protein embeddings after the CLS token
438
+ protein_embeddings = torch.cat(
439
+ [
440
+ protein_embeddings[:, :1, :], # CLS token
441
+ property_embedding, # conditional properties embeddings
442
+ protein_embeddings[:, 1:, :],
443
+ ], # protein embeddings
444
+ dim=1,
445
+ )
446
+
447
+ protein_embeddings = self.LayerNorm(protein_embeddings)
448
+ protein_embeddings = self.dropout(protein_embeddings)
449
+ return protein_embeddings
450
+
451
+
452
+ class BacformerEncoder(nn.Module):
453
+ """Bacformer encoder model"""
454
+
455
+ def __init__(self, config):
456
+ super().__init__()
457
+ self.config = config
458
+
459
+ self.encoder = BacformerTransformerEncoder(
460
+ num_hidden_layers=config.num_hidden_layers,
461
+ hidden_size=config.hidden_size,
462
+ num_attention_heads=config.num_attention_heads,
463
+ intermediate_size=config.intermediate_size,
464
+ activation="gelu",
465
+ dropout=config.attention_probs_dropout_prob,
466
+ )
467
+
468
+ # Note that config.max_position_embeddings is multiplied by 1.5 because the token limit for the Bacformer
469
+ # family of models is 6000. Using this multiplier instead of hard-coding 6000 keeps the token length
470
+ # flexible during training or fine-tuning.
471
+ freqs_cos, freqs_sin = precompute_freqs_cis(
472
+ config.hidden_size // config.num_attention_heads, int(config.max_position_embeddings * 1.5)
473
+ )
474
+ self.register_buffer("freqs_cos", freqs_cos, persistent=False)
475
+ self.register_buffer("freqs_sin", freqs_sin, persistent=False)
476
+
477
+ def forward(
478
+ self,
479
+ hidden_states: torch.Tensor,
480
+ attention_mask: torch.Tensor = None,
481
+ return_attn_weights: Union[bool, None] = None,
482
+ is_causal: bool = False,
483
+ ) -> tuple[torch.Tensor, list[torch.Tensor | None]]:
484
+ """Pass the input through the encoder layers in turn.
485
+
486
+ Args:
487
+ hidden_states: hidden states from the BacformerEmbeddings layer
488
+ attention_mask: mask for the attention in the transformer
489
+ """
490
+ return_attn_weights = (
491
+ return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
492
+ )
493
+ bs, seq_len, _ = hidden_states.shape
494
+ last_hidden_state, attn_weights = self.encoder(
495
+ hidden_state=hidden_states,
496
+ attention_mask=attention_mask,
497
+ freqs_cos=self.freqs_cos[:seq_len, :],
498
+ freqs_sin=self.freqs_sin[:seq_len, :],
499
+ return_attn_weights=return_attn_weights,
500
+ is_causal=is_causal,
501
+ )
502
+ return last_hidden_state, attn_weights
503
+
504
+
505
+ class BacformerPreTrainedModel(PreTrainedModel):
506
+ """An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models."""
507
+
508
+ config_class = BacformerConfig
509
+ base_model_prefix = "bacformer"
510
+ supports_gradient_checkpointing = True
511
+ _no_split_modules = ["BacformerEmbeddings", "BacformerTransformerLayer"]
512
+
513
+ # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
514
+ def _init_weights(self, module):
515
+ """Initialize the weights"""
516
+ if isinstance(module, nn.Linear):
517
+ # Slightly different from the TF version which uses truncated_normal for initialization
518
+ # cf https://github.com/pytorch/pytorch/pull/5617
519
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
520
+ if module.bias is not None:
521
+ module.bias.data.zero_()
522
+ elif isinstance(module, nn.Embedding):
523
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
524
+ if module.padding_idx is not None:
525
+ module.weight.data[module.padding_idx].zero_()
526
+ elif isinstance(module, nn.LayerNorm):
527
+ module.bias.data.zero_()
528
+ module.weight.data.fill_(1.0)
529
+
530
+
531
+ class BacformerModel(BacformerPreTrainedModel):
532
+ """Bacformer model."""
533
+
534
+ def __init__(self, config: BacformerConfig, add_pooling_layer: bool = False):
535
+ super().__init__(config)
536
+ self.config = config
537
+
538
+ self.embeddings = BacformerEmbeddings(config)
539
+ self.encoder = BacformerEncoder(config)
540
+
541
+ self.pooler = BacformerPooler(config) if add_pooling_layer else None
542
+
543
+ # Initialize weights and apply final processing
544
+ self.post_init()
545
+
546
+ def forward(
547
+ self,
548
+ protein_embeddings: torch.Tensor = None,
549
+ special_tokens_mask: torch.Tensor = None,
550
+ token_type_ids: torch.Tensor = None,
551
+ attention_mask: torch.Tensor = None,
552
+ labels: torch.Tensor = None,
553
+ property_ids: torch.Tensor = None,
554
+ return_attn_weights: bool = False,
555
+ return_dict: Union[bool, None] = None,
556
+ is_causal: bool = False,
557
+ ) -> Optional[BacformerModelOutput]:
558
+ """Forward method for the model."""
559
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
560
+ # get embeddings
561
+ protein_embeddings = self.embeddings(
562
+ protein_embeddings=protein_embeddings,
563
+ labels=labels,
564
+ special_tokens_mask=special_tokens_mask,
565
+ token_type_ids=token_type_ids,
566
+ property_ids=property_ids,
567
+ )
568
+
569
+ # create a 4D attention mask from the 2D mask if not doing causal GM
570
+ if attention_mask is not None and not is_causal:
571
+ attention_mask = create_4d_from_2d_attn_mask(
572
+ attn_mask=attention_mask, num_attn_heads=self.config.num_attention_heads
573
+ ).bool()
574
+
575
+ last_hidden_state, attentions = self.encoder(
576
+ hidden_states=protein_embeddings,
577
+ attention_mask=attention_mask,
578
+ return_attn_weights=return_attn_weights,
579
+ is_causal=is_causal,
580
+ )
581
+ pooler_output = (
582
+ self.pooler(hidden_states=last_hidden_state, padding_mask=attention_mask)
583
+ if self.pooler is not None
584
+ else None
585
+ )
586
+
587
+ if not return_dict:
588
+ return (last_hidden_state, pooler_output, attentions)
589
+
590
+ return BacformerModelOutput(
591
+ last_hidden_state=last_hidden_state,
592
+ pooler_output=pooler_output,
593
+ attentions=attentions,
594
+ )
595
+
596
+
597
+ class BacformerForCausalGM(BacformerPreTrainedModel):
598
+ """Bacformer model with genomic modeling head on top"""
599
+
600
+ _tied_weights_keys = ["gm_head.decoder.weight"]
601
+
602
+ def __init__(self, config: BacformerConfig):
603
+ super().__init__(config)
604
+ self.config = config
605
+
606
+ self.bacformer = BacformerModel(config, add_pooling_layer=False)
607
+ self.gm_head = BacformerGMHead(config)
608
+
609
+ # Initialize weights
610
+ self.init_weights()
611
+
612
+ def forward(
613
+ self,
614
+ protein_embeddings: torch.Tensor,
615
+ special_tokens_mask: torch.Tensor,
616
+ labels: torch.Tensor = None,
617
+ token_type_ids: torch.Tensor = None,
618
+ attention_mask: torch.Tensor = None,
619
+ return_attn_weights: bool = None,
620
+ return_dict: Union[bool, None] = None,
621
+ ) -> Optional[BacformerModelOutput]:
622
+ """Forward method for the model."""
623
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
624
+ return_attn_weights = (
625
+ return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
626
+ )
627
+
628
+ outputs = self.bacformer(
629
+ protein_embeddings=protein_embeddings,
630
+ special_tokens_mask=special_tokens_mask,
631
+ token_type_ids=token_type_ids,
632
+ attention_mask=None, # attention mechanism handles the causal mask
633
+ return_attn_weights=return_attn_weights,
634
+ return_dict=return_dict,
635
+ is_causal=True,
636
+ )
637
+ last_hidden_state = outputs[0]
638
+ prediction_scores = self.gm_head(last_hidden_state)
639
+
640
+ loss = None
641
+ if labels is not None:
642
+ labels = labels.to(prediction_scores.device)
643
+
644
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous().view(-1, prediction_scores.shape[-1])
645
+ labels = labels[:, 1:].contiguous().view(-1)
646
+ loss = cross_entropy(shifted_prediction_scores, labels)
647
+
648
+ if not return_dict:
649
+ return (
650
+ loss,
651
+ prediction_scores,
652
+ ) + outputs
653
+
654
+ return BacformerModelOutput(
655
+ loss=loss,
656
+ logits=prediction_scores,
657
+ last_hidden_state=outputs.last_hidden_state,
658
+ attentions=outputs.attentions,
659
+ )
660
+
661
+
662
+ class BacformerForMaskedGM(BacformerPreTrainedModel):
663
+ """Bacformer model with genomic modeling head on top"""
664
+
665
+ _tied_weights_keys = ["gm_head.decoder.weight"]
666
+
667
+ def __init__(self, config: BacformerConfig):
668
+ super().__init__(config)
669
+ self.config = config
670
+
671
+ self.bacformer = BacformerModel(config, add_pooling_layer=False)
672
+ self.gm_head = BacformerGMHead(config)
673
+
674
+ # Initialize weights
675
+ self.init_weights()
676
+
677
+ def forward(
678
+ self,
679
+ protein_embeddings: torch.Tensor,
680
+ special_tokens_mask: torch.Tensor,
681
+ labels: torch.Tensor = None,
682
+ token_type_ids: torch.Tensor = None,
683
+ attention_mask: torch.Tensor = None,
684
+ return_attn_weights: bool = None,
685
+ return_dict: Union[bool, None] = None,
686
+ ) -> Union[BacformerModelOutput, None]:
687
+ """Forward method for the model."""
688
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
689
+ return_attn_weights = (
690
+ return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
691
+ )
692
+
693
+ outputs = self.bacformer(
694
+ protein_embeddings=protein_embeddings,
695
+ special_tokens_mask=special_tokens_mask,
696
+ token_type_ids=token_type_ids,
697
+ attention_mask=attention_mask,
698
+ return_attn_weights=return_attn_weights,
699
+ return_dict=return_dict,
700
+ )
701
+ last_hidden_state = outputs[0]
702
+
705
+ loss = None
706
+ if labels is not None:
707
+ # to speed up the forward pass, only consider the masked tokens
708
+ last_hidden_state = last_hidden_state[labels != -100]
709
+ prediction_scores = self.gm_head(last_hidden_state)
710
+ labels = labels.to(prediction_scores.device)
711
+
717
+ # only the masked tokens (labels != -100) contribute to the loss
718
+ labels = labels[labels != -100]
719
+ loss = cross_entropy(prediction_scores, labels)
720
+ else:
721
+ prediction_scores = self.gm_head(last_hidden_state)
722
+
723
+ if not return_dict:
724
+ return (
725
+ loss,
726
+ prediction_scores,
727
+ ) + outputs
728
+
729
+ return BacformerModelOutput(
730
+ loss=loss,
731
+ logits=prediction_scores,
732
+ last_hidden_state=outputs.last_hidden_state,
733
+ attentions=outputs.attentions,
734
+ )
735
+
736
+
737
+ class BacformerForCausalProteinFamilyModeling(BacformerPreTrainedModel):
738
+ """Bacformer model for causal modeling of protein families. Using protein family as tokens rather than protein embeddings"""
739
+
740
+ _tied_weights_keys = ["gm_head.decoder.weight"]
741
+
742
+ def __init__(
743
+ self,
744
+ config: BacformerConfig,
745
+ n_conditional_properties: int = None,
746
+ initialise_from_non_pfm_model: bool = False,
747
+ ):
748
+ super().__init__(config)
749
+ self.config = config
750
+ self.cls_token_id = SPECIAL_TOKENS_DICT["CLS"]
751
+
752
+ self.bacformer = BacformerModel(config, add_pooling_layer=False)
753
+ self.gm_head = BacformerGMHead(config)
754
+
755
+ if initialise_from_non_pfm_model:
756
+ # Initialize weights
757
+ self.init_weights()
758
+ # overwrite the embeddings with the pretrained
759
+ # protein family embeddings from the decoder of the GM Head
760
+ self.bacformer.embeddings = BacformerProteinFamilyEmbeddings(
761
+ config,
762
+ protein_family_embeddings=self.gm_head.decoder.weight,
763
+ token_type_embeddings=self.bacformer.embeddings.token_type_embeddings.weight,
764
+ special_tokens_embeddings=self.bacformer.embeddings.special_tokens_embeddings.weight,
765
+ n_conditional_properties=n_conditional_properties,
766
+ )
767
+ else:
768
+ self.bacformer.embeddings = BacformerProteinFamilyEmbeddings(
769
+ config,
770
+ n_conditional_properties=n_conditional_properties,
771
+ )
772
+ self.init_weights()
773
+
774
+ def forward(
775
+ self,
776
+ labels: torch.Tensor = None,
777
+ special_tokens_mask: torch.Tensor = None,
778
+ token_type_ids: torch.Tensor = None,
779
+ property_ids: torch.Tensor = None,
780
+ return_attn_weights: bool = None,
781
+ return_dict: Union[bool, None] = None,
782
+ ) -> Optional[BacformerModelOutput]:
783
+ """Forward method for the model."""
784
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
785
+ return_attn_weights = (
786
+ return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
787
+ )
788
+
789
+ outputs = self.bacformer(
790
+ protein_embeddings=None,
791
+ labels=labels,
792
+ special_tokens_mask=special_tokens_mask,
793
+ token_type_ids=token_type_ids,
794
+ property_ids=property_ids,
795
+ return_attn_weights=return_attn_weights,
796
+ return_dict=return_dict,
797
+ is_causal=True,
798
+ )
799
+ last_hidden_state = outputs[0]
800
+ prediction_scores = self.gm_head(last_hidden_state)
801
+
802
+ loss = None
803
+ if labels is not None:
804
+ if property_ids is not None:
805
+ labels = torch.cat(
806
+ [
807
+ torch.tensor([-100], dtype=torch.long)
808
+ .unsqueeze(0)
809
+ .to(labels.device), # account for the property token
810
+ labels,
811
+ ],
812
+ dim=1,
813
+ ) # ignore index
814
+ labels = labels.to(prediction_scores.device)
815
+
816
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous().view(-1, prediction_scores.shape[-1])
817
+ labels = labels[:, 1:].contiguous().view(-1)
818
+ loss = cross_entropy(shifted_prediction_scores, labels)
819
+
820
+ if not return_dict:
821
+ return (
822
+ loss,
823
+ prediction_scores,
824
+ ) + outputs
825
+
826
+ return BacformerModelOutput(
827
+ loss=loss,
828
+ logits=prediction_scores,
829
+ last_hidden_state=outputs.last_hidden_state,
830
+ attentions=outputs.attentions,
831
+ )
832
+
833
+ def generate(
834
+ self,
835
+ protein_family_ids: torch.LongTensor,
836
+ special_tokens_mask: torch.LongTensor = None,
837
+ token_type_ids: torch.LongTensor = None,
838
+ max_length: int = 6000,
839
+ end_token_id: int = 50000,
840
+ do_sample: bool = False,
841
+ top_k: int = 50,
842
+ top_p: float = 1.0,
843
+ temperature: float = 1.0,
844
+ property_ids: torch.LongTensor = None,
845
+ return_last_hidden_states: bool = False,
846
+ ):
847
+ """
848
+ Generate a sequence of tokens autoregressively from a given prompt.
849
+
850
+ Args:
851
+ protein_family_ids (torch.LongTensor): Tensor of shape (batch, seq_len) with token indices.
852
+ max_length (int): Maximum length of the generated sequence (prompt + newly generated).
853
+ end_token_id (int, optional): Token ID signifying end-of-sequence (END).
854
+ If encountered, generation stops.
855
+ do_sample (bool): Whether to sample from the probability distribution (True)
856
+ or use greedy decoding (False).
857
+ top_k (int): If >0, use top-k filtering in sampling mode.
858
+ top_p (float): If <1.0, use nucleus (top-p) filtering in sampling mode.
859
+ temperature (float): Softmax temperature for scaling logits.
860
+ Higher => more random, lower => more deterministic.
861
+ return_last_hidden_states (bool): If True, return final hidden states as well.
862
+
863
+ Returns
864
+ -------
865
+ torch.LongTensor: The generated token sequence of shape (batch, final_seq_len).
866
+ (Optional) torch.FloatTensor: Final hidden states of shape (batch, final_seq_len, hidden_dim)
867
+ if `return_last_hidden_states=True`.
868
+ """
869
+ # Default END token
870
+ if end_token_id is None:
871
+ end_token_id = getattr(self, "end_token_id", None)
872
+
873
+ # Switch to eval mode and move input to correct device
874
+ self.eval()
875
+ device = next(self.parameters()).device
876
+ protein_family_ids = protein_family_ids.to(device)
877
+
878
+ # create a special tokens mask if not provided
879
+ if special_tokens_mask is None:
880
+ # add a cls token at the beginning
881
+ protein_family_ids = torch.cat(
882
+ [torch.tensor([[-100]]).to(device), protein_family_ids],
883
+ dim=1,
884
+ )
885
+ special_tokens_mask = [self.cls_token_id] + [self.config.prot_emb_token_id] * (
886
+ protein_family_ids.shape[1] - 1
887
+ )
888
+ special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.long).unsqueeze(0).to(device)  # shape (1, seq_len) to match the 2D protein_family_ids
889
+
890
+ # create a token type mask if not provided
891
+ if token_type_ids is None:
892
+ token_type_ids = torch.zeros_like(protein_family_ids)
893
+
894
+ # Prepare the initial sequence and define max new tokens
895
+ generated = protein_family_ids.clone()
896
+ batch_size, prompt_length = generated.shape
897
+ max_new_tokens = max_length - prompt_length
898
+ if max_new_tokens <= 0:
899
+ max_new_tokens = 0
900
+
901
+ # Disable gradient calculations for generation
902
+ with torch.no_grad():
903
+ for _step in range(max_new_tokens):
904
+ # Forward pass
905
+ logits = self.forward(
906
+ labels=generated,
907
+ special_tokens_mask=special_tokens_mask,
908
+ # assume it's all on one chromosome
909
+ token_type_ids=token_type_ids,
910
+ property_ids=property_ids,
911
+ return_dict=True,
912
+ ).logits
913
+ # Focus on the last token's logits
914
+ next_token_logits = logits[:, -1, :] # (batch_size, vocab_size)
915
+
916
+ # Apply temperature
917
+ if temperature != 1.0:
918
+ next_token_logits = next_token_logits / temperature
919
+
920
+ # Sampling or greedy?
921
+ if do_sample:
922
+ # Top-k filter
923
+ next_token_logits = top_k_filtering(next_token_logits, top_k=top_k)
924
+ # Top-p filter
925
+ next_token_logits = top_p_filtering(next_token_logits, top_p=top_p)
926
+
927
+ probs = softmax(next_token_logits, dim=-1)
928
+ next_token_id = torch.multinomial(probs, num_samples=1)
929
+ else:
930
+ # Greedy decoding
931
+ next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
932
+
933
+ # Append predicted token
934
+ generated = torch.cat([generated, next_token_id], dim=1)
935
+ special_tokens_mask = torch.cat(
936
+ [special_tokens_mask, torch.tensor([[self.config.prot_emb_token_id]]).to(generated.device)], dim=1
937
+ )
938
+ last_token_type_id = token_type_ids[:, -1].unsqueeze(1)
939
+ token_type_ids = torch.cat([token_type_ids, last_token_type_id], dim=1)
940
+
941
+ # Check for END in all sequences
942
+ if end_token_id is not None:
943
+ if (next_token_id.squeeze(1) == end_token_id).all():
944
+ # If every sequence ended, break early
945
+ break
946
+
947
+ if not return_last_hidden_states:
948
+ return generated
949
+
950
+ # Optionally compute final hidden states
951
+ if return_last_hidden_states:
952
+ last_hidden_state = self.forward(
953
+ labels=generated,
954
+ special_tokens_mask=special_tokens_mask,
955
+ token_type_ids=token_type_ids,
956
+ return_dict=True,
957
+ ).last_hidden_state
958
+
959
+ return generated, last_hidden_state
960
+
961
+
962
+ class BacformerForMaskedGMWithContrastiveLoss(BacformerPreTrainedModel):
963
+ """Bacformer model with genomic modeling head on top"""
964
+
965
+ _tied_weights_keys = ["gm_head.decoder.weight"]
966
+
967
+ def __init__(self, config: BacformerConfig):
968
+ super().__init__(config)
969
+ self.config = config
970
+
971
+ self.bacformer = BacformerModel(config, add_pooling_layer=False)
972
+ self.gm_head = BacformerGMHead(config)
973
+
974
+ # Initialize weights
975
+ self.init_weights()
976
+
977
+ def forward(
978
+ self,
979
+ protein_embeddings: torch.Tensor,
980
+ special_tokens_mask: torch.Tensor,
981
+ labels: torch.Tensor = None,
982
+ token_type_ids: torch.Tensor = None,
983
+ attention_mask: torch.Tensor = None,
984
+ return_attn_weights: bool = None,
985
+ return_dict: Union[bool, None] = None,
986
+ ) -> Union[BacformerModelOutput, None]:
987
+ """Forward method for the model."""
988
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
989
+ return_attn_weights = (
990
+ return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
991
+ )
992
+
993
+ outputs = self.bacformer(
994
+ protein_embeddings=protein_embeddings,
995
+ special_tokens_mask=special_tokens_mask,
996
+ token_type_ids=token_type_ids,
997
+ attention_mask=attention_mask,
998
+ return_attn_weights=return_attn_weights,
999
+ return_dict=return_dict,
1000
+ )
1001
+ last_hidden_state = outputs[0]
1002
+
1005
+ loss = None
1006
+ if labels is not None:
1007
+ # contrastive loss
1008
+ contrastive_loss = compute_contrastive_loss(protein_embeddings, last_hidden_state, special_tokens_mask)
1009
+ # to speed up the forward pass, let's only consider the masked tokens
1010
+ last_hidden_state = last_hidden_state[labels != -100]
1011
+ prediction_scores = self.gm_head(last_hidden_state)
1012
+ labels = labels.to(prediction_scores.device)
1013
+
1014
+ # only considering the masked tokens
1015
+ labels = labels[labels != -100]
1016
+ masked_loss = cross_entropy(prediction_scores, labels)
1017
+ loss = masked_loss + self.config.alpha_contrastive_loss * contrastive_loss
1018
+ else:
1019
+ prediction_scores = self.gm_head(last_hidden_state)
1020
+
1021
+ if not return_dict:
1022
+ return (
1023
+ loss,
1024
+ prediction_scores,
1025
+ ) + outputs
1026
+
1027
+ return BacformerModelOutput(
1028
+ loss=loss,
1029
+ logits=prediction_scores,
1030
+ last_hidden_state=outputs.last_hidden_state,
1031
+ attentions=outputs.attentions,
1032
+ )
1033
+
1034
+
1035
+ class BacformerForProteinClassification(BacformerPreTrainedModel):
1036
+ """Bacformer model with a classification head on top for protein classification tasks."""
1037
+
1038
+ def __init__(self, config: BacformerConfig, benchmark_esm: bool = False):
1039
+ super().__init__(config)
1040
+ self.config = config
1041
+ self.benchmark_esm = benchmark_esm
1042
+
1043
+ self.bacformer = BacformerModel(config, add_pooling_layer=False)
1044
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1045
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1046
+
1047
+ # Initialize weights and apply final processing
1048
+ self.post_init()
1049
+
1050
+ def forward(
1051
+ self,
1052
+ protein_embeddings: torch.Tensor,
1053
+ special_tokens_mask: torch.Tensor,
1054
+ labels: torch.Tensor = None,
1055
+ token_type_ids: torch.Tensor = None,
1056
+ attention_mask: torch.Tensor = None,
1057
+ return_attn_weights: bool = None,
1058
+ return_dict: Union[bool, None] = None,
1059
+ ) -> Optional[BacformerModelOutput]:
1060
+ """Forward method for the model."""
1061
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1062
+ return_attn_weights = (
1063
+ return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
1064
+ )
1065
+
1066
+ if self.benchmark_esm:
1067
+ outputs = [protein_embeddings]
1068
+ else:
1069
+ outputs = self.bacformer(
1070
+ protein_embeddings=protein_embeddings,
1071
+ special_tokens_mask=special_tokens_mask,
1072
+ token_type_ids=token_type_ids,
1073
+ attention_mask=attention_mask,
1074
+ return_attn_weights=return_attn_weights,
1075
+ return_dict=return_dict,
1076
+ )
1077
+
1078
+ last_hidden_state = outputs[0]
1079
+
1080
+ last_hidden_state = self.dropout(last_hidden_state)
1081
+ logits = self.classifier(last_hidden_state)
1082
+
1083
+ loss = None
1084
+ if labels is not None:
1085
+ labels = labels.to(logits.device)
1086
+
1087
+ if self.config.problem_type == "regression":
1088
+ loss = mse_loss(logits, labels)
1089
+ elif self.config.problem_type == "single_label_classification":
1090
+ loss = cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
1091
+ elif (
1092
+ self.config.problem_type == "multi_label_classification"
1093
+ or self.config.problem_type == "binary_classification"
1094
+ ):
1095
+ # remove the -100 labels from loss computation
1096
+ mask = torch.ones_like(labels.view(-1)) - (labels.view(-1) == -100.0).float()
1097
+ loss = binary_cross_entropy_with_logits(
1098
+ logits.view(-1), labels.view(-1).type_as(logits), reduction="none"
1099
+ )
1100
+ loss = (loss * mask).sum() / mask.sum()
1101
+
1102
+ if not return_dict:
1103
+ return (
1104
+ loss,
1105
+ None,
1106
+ logits,
1107
+ ) # + outputs
1108
+
1109
+ return BacformerModelOutput(
1110
+ loss=loss,
1111
+ logits=logits,
1112
+ last_hidden_state=last_hidden_state,
1113
+ attentions=outputs.attentions,
1114
+ )
1115
+
1116
+
1117
+ class BacformerForGenomeClassification(BacformerPreTrainedModel):
1118
+ """Bacformer model with a classification head on top for genome classification tasks."""
1119
+
1120
+ def __init__(self, config: BacformerConfig):
1121
+ super().__init__(config)
1122
+ self.config = config
1123
+
1124
+ self.bacformer = BacformerModel(config, add_pooling_layer=False)
1125
+ self.classifier = BacformerGenomeClassificationHead(config)
1126
+
1127
+ # Initialize weights and apply final processing
1128
+ self.post_init()
1129
+
1130
+ def forward(
1131
+ self,
1132
+ protein_embeddings: torch.Tensor,
1133
+ special_tokens_mask: torch.Tensor,
1134
+ labels: torch.Tensor = None,
1135
+ token_type_ids: torch.Tensor = None,
1136
+ attention_mask: torch.Tensor = None,
1137
+ return_attn_weights: bool = None,
1138
+ return_dict: Union[bool, None] = None,
1139
+ ) -> Optional[BacformerModelOutput]:
1140
+ """Forward method for the model."""
1141
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1142
+ return_attn_weights = (
1143
+ return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
1144
+ )
1145
+
1146
+ outputs = self.bacformer(
1147
+ protein_embeddings=protein_embeddings,
1148
+ special_tokens_mask=special_tokens_mask,
1149
+ token_type_ids=token_type_ids,
1150
+ attention_mask=attention_mask,
1151
+ return_attn_weights=return_attn_weights,
1152
+ return_dict=return_dict,
1153
+ )
1154
+ last_hidden_state = outputs[0]
1155
+ logits = self.classifier(last_hidden_state, attention_mask)
1156
+
1157
+ loss = None
1158
+ if labels is not None:
1159
+ labels = labels.to(logits.device)
1160
+
1161
+ if self.config.problem_type == "regression":
1162
+ loss = mse_loss(logits.view(-1), labels.view(-1))
1163
+ elif self.config.problem_type == "binary_classification":
1164
+ loss = binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
1165
+ elif self.config.problem_type == "single_label_classification":
1166
+ loss = cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
1167
+ elif self.config.problem_type == "multi_label_classification":
1168
+ loss = binary_cross_entropy_with_logits(logits, labels)
1169
+
1170
+ if not return_dict:
1171
+ return (
1172
+ loss,
1173
+ None,
1174
+ logits,
1175
+ )
1176
+
1177
+ return BacformerModelOutput(
1178
+ loss=loss,
1179
+ logits=logits,
1180
+ last_hidden_state=outputs.last_hidden_state,
1181
+ attentions=outputs.attentions,
1182
+ )
1183
+
1184
+
1185
+ class BacformerForProteinProteinInteraction(BacformerPreTrainedModel):
1186
+ """Bacformer model with a protein-protein interaction head on top."""
1187
+
1188
+ def __init__(self, config: BacformerConfig, benchmark_esm: bool = False):
1189
+ super().__init__(config)
1190
+ self.config = config
1191
+ self.benchmark_esm = benchmark_esm
1192
+ print("Benchmark ESM:", self.benchmark_esm)
1193
+ self.return_attn_weights = config.return_attn_weights
1194
+
1195
+ self.bacformer = BacformerModel(config, add_pooling_layer=False)
1196
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1197
+ self.dense = nn.Sequential(
1198
+ nn.Linear(config.hidden_size, config.hidden_size),
1199
+ nn.GELU(),
1200
+ nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
1201
+ nn.Dropout(0.2),
1202
+ )
1203
+ self.ppi_head = BacformerProteinProteinInteractionHead(
1204
+ in_features=config.hidden_size, prot_emb_idx=config.prot_emb_token_id
1205
+ )
1206
+
1207
+ # Initialize weights and apply final processing
1208
+ self.post_init()
1209
+
1210
+ def forward(
1211
+ self,
1212
+ protein_embeddings: torch.Tensor,
1213
+ special_tokens_mask: torch.Tensor,
1214
+ labels: torch.Tensor = None,
1215
+ token_type_ids: torch.Tensor = None,
1216
+ attention_mask: torch.Tensor = None,
1217
+ return_attn_weights: bool = None,
1218
+ return_dict: Union[bool, None] = None,
1219
+ ) -> Union[OrderedDict, None]: # TODO: change it from token classifier output
1220
+ """Forward method for the model."""
1221
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1222
+
1223
+ if self.benchmark_esm:
1224
+ last_hidden_state = protein_embeddings.squeeze(0)[1:-2, :]
1225
+ else:
1226
+ outputs = self.bacformer(
1227
+ protein_embeddings=protein_embeddings,
1228
+ special_tokens_mask=special_tokens_mask,
1229
+ token_type_ids=token_type_ids,
1230
+ attention_mask=attention_mask,
1231
+ return_attn_weights=False,
1232
+ return_dict=True,
1233
+ )
1234
+ last_hidden_state = outputs.last_hidden_state.squeeze(0)[1:-2, :]
1235
+
1236
+ assert labels.shape[0] == 1, "Batch size should be 1 for protein-protein interaction task"
1237
+
1238
+ last_hidden_state = self.dense(self.dropout(last_hidden_state))
1239
+ last_hidden_state = torch.cat([last_hidden_state[labels[:, 0]], last_hidden_state[labels[:, 1]]], dim=0).mean(
1240
+ dim=0
1241
+ )
1242
+ logits = self.ppi_head(last_hidden_state)
1243
+
1244
+ loss = binary_cross_entropy_with_logits(logits, labels[:, 2].type_as(logits).squeeze(0))
1245
+
1246
+ if not return_dict:
1247
+ return (
1248
+ loss,
1249
+ logits,
1250
+ )
1251
+
1252
+ return BacformerModelOutput(
1253
+ loss=loss,
1254
+ logits=logits,
1255
+ last_hidden_state=outputs.last_hidden_state,
1256
+ attentions=outputs.attentions,
1257
+ )
1258
+
1259
+
1260
+ # Adapted from transformers.models.bert.modeling_bert.BertPooler (mean-pools non-padding tokens instead of taking the first token)
1261
+ class BacformerPooler(nn.Module):
1262
+ """Pooler for Bacformer model."""
1263
+
1264
+ def __init__(self, config: BacformerConfig):
1265
+ super().__init__()
1266
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1267
+ self.activation = nn.Tanh()
1268
+
1269
+ def forward(self, hidden_states: torch.Tensor, padding_mask: torch.Tensor = None) -> torch.Tensor:
1270
+ """Forward method for the pooler."""
1271
+ # We "pool" the model by taking the mean of non-padding tokens
1272
+ padding_mask = padding_mask.to(hidden_states.device) if padding_mask is not None else None
1273
+ if padding_mask is not None:
1274
+ mean_hidden_states = torch.einsum("ijk,ij->ik", hidden_states, padding_mask) / padding_mask.sum(
1275
+ 1
1276
+ ).unsqueeze(1)
1277
+ else:
1278
+ mean_hidden_states = hidden_states.mean(dim=1)
1279
+ pooled_output = self.dense(mean_hidden_states)
1280
+ pooled_output = self.activation(pooled_output)
1281
+ return pooled_output
1282
+
1283
+
1284
+ class BacformerGMHead(nn.Module):
1285
+ """Bacformer Head for genomic modeling."""
1286
+
1287
+ def __init__(self, config):
1288
+ super().__init__()
1289
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1290
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1291
+
1292
+ # add 1 to config.protein_clusters_vocab_size to account for the end token
1293
+ self.decoder = nn.Linear(config.hidden_size, config.protein_clusters_vocab_size + 1, bias=False)
1294
+ self.bias = nn.Parameter(torch.zeros(config.protein_clusters_vocab_size + 1))
1295
+
1296
+ def forward(self, features, **kwargs):
1297
+ """Forward method for the head."""
1298
+ x = self.dense(features)
1299
+ x = gelu(x)
1300
+ x = self.layer_norm(x)
1301
+
1302
+ # project back to nr of labels with bias
1303
+ x = self.decoder(x) + self.bias
1304
+ return x
1305
+
1306
+
1307
+ class BacformerGenomeClassificationHead(nn.Module):
1308
+ """Head for genome-level classification tasks."""
1309
+
1310
+ def __init__(self, config: BacformerConfig):
1311
+ super().__init__()
1312
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1313
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
1314
+
1315
+ def forward(self, features: torch.Tensor, padding_mask: torch.Tensor, **kwargs):
1316
+ """Forward method for the head."""
1317
+ if padding_mask is not None:
1318
+ x = torch.einsum("ijk,ij->ik", features, padding_mask) / padding_mask.sum(1).unsqueeze(1)
1319
+ else:
1320
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
1321
+ x = self.dropout(x)
1322
+ x = self.out_proj(x)
1323
+ return x
1324
+
1325
+
1326
+ class BacformerProteinProteinInteractionHead(nn.Module):
1327
+ """Head for protein-protein interaction task at a genome level."""
1328
+
1329
+ def __init__(self, in_features: int, prot_emb_idx: int = 4, bias: bool = True):
1330
+ super().__init__()
1331
+ self.in_features = in_features
1332
+ self.prot_emb_idx = prot_emb_idx
1333
+ self.dropout = nn.Dropout(0.2)
1334
+ self.linear = nn.Linear(in_features, 1, bias=bias)
1335
+
1336
+ def forward(
1337
+ self, hidden_states: torch.Tensor
1338
+ ) -> torch.Tensor: # special_tokens_mask: torch.Tensor, attentions: torch.Tensor):
1339
+ """Forward method for the head."""
1340
+ return self.linear(self.dropout(hidden_states)).squeeze(-1)
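
Note: a minimal usage sketch for the BacformerForMaskedGM class defined above. Tensor shapes and the use of config.hidden_size / config.prot_emb_token_id follow the forward signatures in this file; the repository id, the family id 42, and the choice to leave the masked position tagged as a protein-embedding token (rather than the MASK id from SPECIAL_TOKENS_DICT) are illustrative assumptions, not values from this commit.

import torch
from transformers import AutoConfig, AutoModelForMaskedLM

repo_id = "<namespace>/<bacformer-repo>"  # placeholder repository id
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True).eval()

batch_size, n_proteins = 1, 8
# pooled per-protein embeddings (e.g. mean-pooled ESM embeddings), one vector per protein
protein_embeddings = torch.randn(batch_size, n_proteins, config.hidden_size)
# tag every position as a protein-embedding token; a real masking pipeline would
# place the MASK special-token id at the masked positions instead
special_tokens_mask = torch.full((batch_size, n_proteins), config.prot_emb_token_id, dtype=torch.long)
attention_mask = torch.ones(batch_size, n_proteins, dtype=torch.long)
# labels hold the protein-family id at masked positions and -100 everywhere else
labels = torch.full((batch_size, n_proteins), -100, dtype=torch.long)
labels[0, 3] = 42  # pretend position 3 is masked and belongs to family 42 (arbitrary id)

with torch.no_grad():
    out = model(
        protein_embeddings=protein_embeddings,
        special_tokens_mask=special_tokens_mask,
        attention_mask=attention_mask,
        labels=labels,
        return_dict=True,
    )
print(out.loss, out.logits.shape)  # logits are returned only for the masked positions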
utils_bacformer.py ADDED
@@ -0,0 +1,109 @@
1
+ import torch
2
+ from torch.nn.functional import cross_entropy, softmax
3
+
4
+ from .configuration_bacformer import SPECIAL_TOKENS_DICT
5
+
6
+
7
+ def compute_contrastive_loss(
8
+ protein_embeddings: torch.Tensor,
9
+ last_hidden_state: torch.Tensor,
10
+ special_tokens_mask: torch.Tensor,
11
+ ) -> torch.Tensor:
12
+ """Compute contrastive loss between protein embeddings and masked items."""
13
+ # keep protein embeddings and masked items
14
+ # ensure the batch size is 1, the model currently does not work with batch size > 1
15
+ assert protein_embeddings.shape[0] == last_hidden_state.shape[0] == 1
16
+
17
+ # subset to mask and protein embedding tokens
18
+ special_tokens_mask = special_tokens_mask.squeeze(0)
19
+ mask = (special_tokens_mask == SPECIAL_TOKENS_DICT["PROT_EMB"]) | (
20
+ special_tokens_mask == SPECIAL_TOKENS_DICT["MASK"]
21
+ )
22
+ protein_embeddings = protein_embeddings.squeeze(0)[mask]
23
+ last_hidden_state = last_hidden_state.squeeze(0)[mask]
24
+
25
+ # Normalize embeddings
26
+ last_hidden_state = last_hidden_state / last_hidden_state.norm(dim=1, keepdim=True)
27
+ protein_embeddings = protein_embeddings / protein_embeddings.norm(dim=1, keepdim=True)
28
+
29
+ # Compute similarity matrix and loss as before
30
+ similarity_matrix = torch.matmul(last_hidden_state, protein_embeddings.T)
31
+
32
+ n_prots = protein_embeddings.shape[0]
33
+ labels = torch.arange(n_prots).to(protein_embeddings.device)
34
+
35
+ # Compute the loss
36
+ loss = cross_entropy(similarity_matrix, labels)
37
+ return loss
38
+
39
+
40
+ def top_k_filtering(logits: torch.Tensor, top_k: int = 50):
41
+ """
42
+ Keep only top_k logits and set the rest to -inf.
43
+
44
+ Args:
45
+ logits (torch.Tensor): Logits of shape (batch_size, vocab_size).
46
+ top_k (int): The number of highest probability logits to keep.
47
+
48
+ Returns
49
+ -------
50
+ torch.Tensor: Filtered logits where only the top k values remain, and all others are -inf.
51
+ """
52
+ if top_k <= 0:
53
+ return logits
54
+
55
+ # Find top_k values
56
+ top_k = min(top_k, logits.size(-1))
57
+ vals, idx = torch.topk(logits, top_k, dim=-1)
58
+ # Get the smallest logit in the top_k
59
+ min_vals = vals[:, -1].unsqueeze(-1)
60
+ # Mask all logits that are < this min value
61
+ mask = logits < min_vals
62
+ logits[mask] = float("-inf")
63
+ return logits
64
+
65
+
66
+ def top_p_filtering(logits: torch.Tensor, top_p: float = 0.9):
67
+ """
68
+ Keep the smallest set of logits whose cumulative probability >= top_p.
69
+
70
+ Args:
71
+ logits (torch.Tensor): Logits of shape (batch_size, vocab_size).
72
+ top_p (float): Cumulative probability threshold.
73
+
74
+ Returns
75
+ -------
76
+ torch.Tensor: Filtered logits where only tokens within the top_p cumulative
77
+ probability mass are kept; the rest are set to -inf.
78
+ """
79
+ if top_p >= 1.0:
80
+ return logits
81
+
82
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
83
+ cumulative_probs = torch.cumsum(softmax(sorted_logits, dim=-1), dim=-1)
84
+
85
+ # Identify where cumulative probability exceeds top_p
86
+ sorted_indices_to_remove = cumulative_probs > top_p
87
+ # Shift the mask to ensure we always keep at least one token
88
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
89
+ sorted_indices_to_remove[..., 0] = False
90
+
91
+ # Scatter to replicate the mask in the original ordering
92
+ for i in range(logits.size(0)):
93
+ remove_indices = sorted_indices[i, sorted_indices_to_remove[i]]
94
+ logits[i, remove_indices] = float("-inf")
95
+
96
+ return logits
97
+
98
+
99
+ def create_4d_from_2d_attn_mask(attn_mask: torch.Tensor, num_attn_heads: int):
100
+ """Helper function to reshape attn_mask to 3D from 2D"""
101
+ assert (
102
+ len(attn_mask.shape) == 2
103
+ ), f"Please provide attn_mask of shape (batch_size, seq_len), current shape {attn_mask.shape}"
104
+
105
+ bs, seq_len = attn_mask.shape
106
+ attn_mask = attn_mask.view(bs, 1, 1, seq_len)
107
+ attn_mask = attn_mask.expand(-1, num_attn_heads, -1, -1)
108
+ attn_mask = attn_mask.view(bs, num_attn_heads, -1, seq_len)
109
+ return attn_mask
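
Note: a short standalone sketch of how the sampling helpers above combine during generation, mirroring the do_sample branch of BacformerForCausalProteinFamilyModeling.generate. The flat import path is an assumption; inside the repo these functions are imported as .utils_bacformer.

import torch
from torch.nn.functional import softmax

from utils_bacformer import top_k_filtering, top_p_filtering  # assumed local import path

logits = torch.randn(2, 10)                    # (batch_size, vocab_size)
logits = top_k_filtering(logits, top_k=5)      # keep only the 5 largest logits per row
logits = top_p_filtering(logits, top_p=0.9)    # then keep the smallest set covering 0.9 probability mass
probs = softmax(logits, dim=-1)                # filtered positions get probability 0
next_token = torch.multinomial(probs, num_samples=1)
print(next_token.shape)                        # torch.Size([2, 1])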