macwiatrak committed (verified)
Commit 7a26e46 · 1 parent: 0fed78d

Upload BacformerForCausalGM

Files changed (3):
  1. config.json +7 -202
  2. configuration_bacformer.py +72 -0
  3. modeling_bacformer.py +1461 -0
config.json CHANGED
@@ -1,10 +1,13 @@
 {
-  "_name_or_path": "/rds/user/mw896/rds-flotolab-9X9gY1OFt4M/projects/bacformer/output-data/all-genomes/runs-causal/12L-full-data-rotary-lr15e-5/checkpoint-176000/",
   "alpha_contrastive_loss": 0.5,
   "architectures": [
     "BacformerForCausalGM"
   ],
   "attention_probs_dropout_prob": 0.1,
+  "auto_map": {
+    "AutoConfig": "configuration_bacformer.BacformerConfig",
+    "AutoModelForCausalLM": "modeling_bacformer.BacformerForCausalGM"
+  },
   "batch_size": 1,
   "ckpt_path": null,
   "dataloader_num_workers": 10,
@@ -15,212 +18,14 @@
   "hidden_dropout_prob": 0.1,
   "hidden_size": 480,
   "id2label": {
-    "0": "LABEL_0",
-    "1": "LABEL_1",
-    "2": "LABEL_2",
-    "3": "LABEL_3",
-    "4": "LABEL_4",
-    "5": "LABEL_5",
-    "6": "LABEL_6",
-    "7": "LABEL_7",
-    "8": "LABEL_8",
-    "9": "LABEL_9",
-    "10": "LABEL_10",
-    "11": "LABEL_11",
-    "12": "LABEL_12",
-    "13": "LABEL_13",
-    "14": "LABEL_14",
-    "15": "LABEL_15",
-    "16": "LABEL_16",
-    "17": "LABEL_17",
-    "18": "LABEL_18",
-    "19": "LABEL_19",
-    "20": "LABEL_20",
-    "21": "LABEL_21",
-    "22": "LABEL_22",
-    "23": "LABEL_23",
-    "24": "LABEL_24",
-    "25": "LABEL_25",
-    "26": "LABEL_26",
-    "27": "LABEL_27",
-    "28": "LABEL_28",
-    "29": "LABEL_29",
-    "30": "LABEL_30",
-    "31": "LABEL_31",
-    "32": "LABEL_32",
-    "33": "LABEL_33",
-    "34": "LABEL_34",
-    "35": "LABEL_35",
-    "36": "LABEL_36",
-    "37": "LABEL_37",
-    "38": "LABEL_38",
-    "39": "LABEL_39",
-    "40": "LABEL_40",
-    "41": "LABEL_41",
-    "42": "LABEL_42",
-    "43": "LABEL_43",
-    "44": "LABEL_44",
-    "45": "LABEL_45",
-    "46": "LABEL_46",
-    "47": "LABEL_47",
-    "48": "LABEL_48",
-    "49": "LABEL_49",
-    "50": "LABEL_50",
-    "51": "LABEL_51",
-    "52": "LABEL_52",
-    "53": "LABEL_53",
-    "54": "LABEL_54",
-    "55": "LABEL_55",
-    "56": "LABEL_56",
-    "57": "LABEL_57",
-    "58": "LABEL_58",
-    "59": "LABEL_59",
-    "60": "LABEL_60",
-    "61": "LABEL_61",
-    "62": "LABEL_62",
-    "63": "LABEL_63",
-    "64": "LABEL_64",
-    "65": "LABEL_65",
-    "66": "LABEL_66",
-    "67": "LABEL_67",
-    "68": "LABEL_68",
-    "69": "LABEL_69",
-    "70": "LABEL_70",
-    "71": "LABEL_71",
-    "72": "LABEL_72",
-    "73": "LABEL_73",
-    "74": "LABEL_74",
-    "75": "LABEL_75",
-    "76": "LABEL_76",
-    "77": "LABEL_77",
-    "78": "LABEL_78",
-    "79": "LABEL_79",
-    "80": "LABEL_80",
-    "81": "LABEL_81",
-    "82": "LABEL_82",
-    "83": "LABEL_83",
-    "84": "LABEL_84",
-    "85": "LABEL_85",
-    "86": "LABEL_86",
-    "87": "LABEL_87",
-    "88": "LABEL_88",
-    "89": "LABEL_89",
-    "90": "LABEL_90",
-    "91": "LABEL_91",
-    "92": "LABEL_92",
-    "93": "LABEL_93",
-    "94": "LABEL_94",
-    "95": "LABEL_95",
-    "96": "LABEL_96",
-    "97": "LABEL_97",
-    "98": "LABEL_98",
-    "99": "LABEL_99"
+    "0": "LABEL_0"
   },
   "initializer_range": 0.02,
   "input_dir": "/rds/user/mw896/rds-flotolab-9X9gY1OFt4M/projects/bacformer/input-data/eval-genomes/",
   "intermediate_size": 1280,
   "is_causal_gm": true,
   "label2id": {
-    "LABEL_0": 0,
-    "LABEL_1": 1,
-    "LABEL_10": 10,
-    "LABEL_11": 11,
-    "LABEL_12": 12,
-    "LABEL_13": 13,
-    "LABEL_14": 14,
-    "LABEL_15": 15,
-    "LABEL_16": 16,
-    "LABEL_17": 17,
-    "LABEL_18": 18,
-    "LABEL_19": 19,
-    "LABEL_2": 2,
-    "LABEL_20": 20,
-    "LABEL_21": 21,
-    "LABEL_22": 22,
-    "LABEL_23": 23,
-    "LABEL_24": 24,
-    "LABEL_25": 25,
-    "LABEL_26": 26,
-    "LABEL_27": 27,
-    "LABEL_28": 28,
-    "LABEL_29": 29,
-    "LABEL_3": 3,
-    "LABEL_30": 30,
-    "LABEL_31": 31,
-    "LABEL_32": 32,
-    "LABEL_33": 33,
-    "LABEL_34": 34,
-    "LABEL_35": 35,
-    "LABEL_36": 36,
-    "LABEL_37": 37,
-    "LABEL_38": 38,
-    "LABEL_39": 39,
-    "LABEL_4": 4,
-    "LABEL_40": 40,
-    "LABEL_41": 41,
-    "LABEL_42": 42,
-    "LABEL_43": 43,
-    "LABEL_44": 44,
-    "LABEL_45": 45,
-    "LABEL_46": 46,
-    "LABEL_47": 47,
-    "LABEL_48": 48,
-    "LABEL_49": 49,
-    "LABEL_5": 5,
-    "LABEL_50": 50,
-    "LABEL_51": 51,
-    "LABEL_52": 52,
-    "LABEL_53": 53,
-    "LABEL_54": 54,
-    "LABEL_55": 55,
-    "LABEL_56": 56,
-    "LABEL_57": 57,
-    "LABEL_58": 58,
-    "LABEL_59": 59,
-    "LABEL_6": 6,
-    "LABEL_60": 60,
-    "LABEL_61": 61,
-    "LABEL_62": 62,
-    "LABEL_63": 63,
-    "LABEL_64": 64,
-    "LABEL_65": 65,
-    "LABEL_66": 66,
-    "LABEL_67": 67,
-    "LABEL_68": 68,
-    "LABEL_69": 69,
-    "LABEL_7": 7,
-    "LABEL_70": 70,
-    "LABEL_71": 71,
-    "LABEL_72": 72,
-    "LABEL_73": 73,
-    "LABEL_74": 74,
-    "LABEL_75": 75,
-    "LABEL_76": 76,
-    "LABEL_77": 77,
-    "LABEL_78": 78,
-    "LABEL_79": 79,
-    "LABEL_8": 8,
-    "LABEL_80": 80,
-    "LABEL_81": 81,
-    "LABEL_82": 82,
-    "LABEL_83": 83,
-    "LABEL_84": 84,
-    "LABEL_85": 85,
-    "LABEL_86": 86,
-    "LABEL_87": 87,
-    "LABEL_88": 88,
-    "LABEL_89": 89,
-    "LABEL_9": 9,
-    "LABEL_90": 90,
-    "LABEL_91": 91,
-    "LABEL_92": 92,
-    "LABEL_93": 93,
-    "LABEL_94": 94,
-    "LABEL_95": 95,
-    "LABEL_96": 96,
-    "LABEL_97": 97,
-    "LABEL_98": 98,
-    "LABEL_99": 99
+    "LABEL_0": 0
   },
   "layer_norm_eps": 1e-12,
   "logging_steps": 500,
@@ -262,7 +67,7 @@
   "test_after_train": false,
   "torch_dtype": "float32",
   "train_subset_prop": 1.0,
-  "transformers_version": "4.38.2",
+  "transformers_version": "4.50.3",
   "warmup_proportion": 0.1,
   "weight_decay": 0.01
 }
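Note: with the `auto_map` entries added above, the checkpoint should be loadable through the Transformers Auto classes once remote code is enabled. A minimal sketch, assuming the model is pulled from the Hub (the repo id below is a placeholder, not taken from this commit):

```python
from transformers import AutoConfig, AutoModelForCausalLM

# Placeholder repo id -- substitute the actual Hub path of this model repository.
repo_id = "<namespace>/<bacformer-repo>"

# trust_remote_code is required so that the auto_map entries resolve to
# configuration_bacformer.BacformerConfig and modeling_bacformer.BacformerForCausalGM.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

print(type(model).__name__)  # expected: BacformerForCausalGM
```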
configuration_bacformer.py ADDED
@@ -0,0 +1,72 @@
+from typing import Literal
+
+from transformers import PretrainedConfig
+
+SPECIAL_TOKENS_DICT = {
+    "PAD": 0,
+    "MASK": 1,
+    "CLS": 2,
+    "SEP": 3,
+    "PROT_EMB": 4,
+    "END": 5,
+}
+
+
+class BacformerConfig(PretrainedConfig):
+    """Configuration class to store the configuration of a `BacformerModel`."""
+
+    model_type = "bacformer"
+
+    def __init__(
+        self,
+        num_hidden_layers: int = 6,
+        num_attention_heads: int = 8,
+        hidden_size: int = 480,  # default esm2_t12_35M_UR50D embedding dim
+        intermediate_size: int = 1280,
+        hidden_dropout_prob: float = 0.1,
+        attention_probs_dropout_prob: float = 0.1,
+        max_position_embeddings: int = 6000,
+        max_token_type_embeddings: int = 1000,
+        layer_norm_eps: float = 1e-12,
+        initializer_range: float = 0.02,
+        pad_token_id: int = SPECIAL_TOKENS_DICT["PAD"],
+        mask_token_id: int = SPECIAL_TOKENS_DICT["MASK"],
+        prot_emb_token_id: int = SPECIAL_TOKENS_DICT["PROT_EMB"],
+        end_token_id: int = SPECIAL_TOKENS_DICT["END"],
+        num_special_tokens: int = len(SPECIAL_TOKENS_DICT),
+        protein_clusters_vocab_size: int = 50001,  # equal to the nr of protein clusters + 1
+        num_labels: int = 1,  # for downstream tasks
+        is_causal_gm: bool = False,
+        return_dict: bool = False,
+        return_attn_weights: bool = False,
+        alpha_contrastive_loss: float = 0.5,
+        # only to use in the BacformerForGenomeClassification
+        problem_type: Literal[
+            "regression", "binary_classification", "single_label_classification", "multi_label_classification"
+        ] = "single_label_classification",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.max_token_type_embeddings = max_token_type_embeddings
+        self.layer_norm_eps = layer_norm_eps
+        self.initializer_range = initializer_range
+        self.pad_token_id = pad_token_id
+        self.mask_token_id = mask_token_id
+        self.prot_emb_token_id = prot_emb_token_id
+        self.end_token_id = end_token_id
+        self.num_special_tokens = num_special_tokens
+        self.protein_clusters_vocab_size = protein_clusters_vocab_size
+        self.num_labels = num_labels
+        self.is_causal_gm = is_causal_gm
+        self.return_dict = return_dict
+        self.return_attn_weights = return_attn_weights
+        self.problem_type = problem_type
+        self.alpha_contrastive_loss = alpha_contrastive_loss
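A minimal sketch of instantiating the configuration above directly, assuming `configuration_bacformer.py` is importable from the working directory (the layer count is only suggested by the "12L" run name in the old checkpoint path and is an assumption):

```python
from configuration_bacformer import SPECIAL_TOKENS_DICT, BacformerConfig

config = BacformerConfig(
    num_hidden_layers=12,   # assumed from the "12L" run name; the class default is 6
    hidden_size=480,        # matches the ESM-2 t12 35M embedding dimension noted above
    intermediate_size=1280,
    is_causal_gm=True,
)
print(config.model_type, config.pad_token_id, SPECIAL_TOKENS_DICT["PROT_EMB"])
```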
modeling_bacformer.py ADDED
@@ -0,0 +1,1461 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from collections import OrderedDict
3
+ from dataclasses import dataclass
4
+ from typing import Literal, Optional, Union
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn.functional import (
9
+ binary_cross_entropy_with_logits,
10
+ cross_entropy,
11
+ gelu,
12
+ mse_loss,
13
+ scaled_dot_product_attention,
14
+ softmax,
15
+ )
16
+ from transformers import PreTrainedModel
17
+ from transformers.utils import ModelOutput
18
+
19
+ from bacformer_model.configuration_bacformer import SPECIAL_TOKENS_DICT, BacformerConfig
20
+
21
+
22
+ def compute_contrastive_loss(
23
+ protein_embeddings: torch.Tensor,
24
+ last_hidden_state: torch.Tensor,
25
+ special_tokens_mask: torch.Tensor,
26
+ ) -> torch.Tensor:
27
+ """Compute contrastive loss between protein embeddings and masked items."""
28
+ # keep protein embeddings and masked items
29
+ # ensure the batch size is 1, the model currently does not work with batch size > 1
30
+ assert protein_embeddings.shape[0] == last_hidden_state.shape[0] == 1
31
+
32
+ # subset to mask and protein embedding tokens
33
+ special_tokens_mask = special_tokens_mask.squeeze(0)
34
+ mask = (special_tokens_mask == SPECIAL_TOKENS_DICT["PROT_EMB"]) | (
35
+ special_tokens_mask == SPECIAL_TOKENS_DICT["MASK"]
36
+ )
37
+ protein_embeddings = protein_embeddings.squeeze(0)[mask]
38
+ last_hidden_state = last_hidden_state.squeeze(0)[mask]
39
+
40
+ # Normalize embeddings
41
+ last_hidden_state = last_hidden_state / last_hidden_state.norm(dim=1, keepdim=True)
42
+ protein_embeddings = protein_embeddings / protein_embeddings.norm(dim=1, keepdim=True)
43
+
44
+ # Compute similarity matrix and loss as before
45
+ similarity_matrix = torch.matmul(last_hidden_state, protein_embeddings.T)
46
+
47
+ n_prots = protein_embeddings.shape[0]
48
+ labels = torch.arange(n_prots).to(protein_embeddings.device)
49
+
50
+ # Compute the loss
51
+ loss = cross_entropy(similarity_matrix, labels)
52
+ return loss
53
+
54
+
55
+ def top_k_filtering(logits: torch.Tensor, top_k: int = 50):
56
+ """
57
+ Keep only top_k logits and set the rest to -inf.
58
+
59
+ Args:
60
+ logits (torch.Tensor): Logits of shape (batch_size, vocab_size).
61
+ top_k (int): The number of highest probability logits to keep.
62
+
63
+ Returns
64
+ -------
65
+ torch.Tensor: Filtered logits where only the top k values remain, and all others are -inf.
66
+ """
67
+ if top_k <= 0:
68
+ return logits
69
+
70
+ # Find top_k values
71
+ top_k = min(top_k, logits.size(-1))
72
+ vals, idx = torch.topk(logits, top_k, dim=-1)
73
+ # Get the smallest logit in the top_k
74
+ min_vals = vals[:, -1].unsqueeze(-1)
75
+ # Mask all logits that are < this min value
76
+ mask = logits < min_vals
77
+ logits[mask] = float("-inf")
78
+ return logits
79
+
80
+
81
+ def top_p_filtering(logits: torch.Tensor, top_p: float = 0.9):
82
+ """
83
+ Keep the smallest set of logits whose cumulative probability >= top_p.
84
+
85
+ Args:
86
+ logits (torch.Tensor): Logits of shape (batch_size, vocab_size).
87
+ top_p (float): Cumulative probability threshold.
88
+
89
+ Returns
90
+ -------
91
+ torch.Tensor: Filtered logits where only tokens within the top_p cumulative
92
+ probability mass are kept; the rest are set to -inf.
93
+ """
94
+ if top_p >= 1.0:
95
+ return logits
96
+
97
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
98
+ cumulative_probs = torch.cumsum(softmax(sorted_logits, dim=-1), dim=-1)
99
+
100
+ # Identify where cumulative probability exceeds top_p
101
+ sorted_indices_to_remove = cumulative_probs > top_p
102
+ # Shift the mask to ensure we always keep at least one token
103
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
104
+ sorted_indices_to_remove[..., 0] = False
105
+
106
+ # Scatter to replicate the mask in the original ordering
107
+ for i in range(logits.size(0)):
108
+ remove_indices = sorted_indices[i, sorted_indices_to_remove[i]]
109
+ logits[i, remove_indices] = float("-inf")
110
+
111
+ return logits
112
+
113
+
114
+ def create_4d_from_2d_attn_mask(attn_mask: torch.Tensor, num_attn_heads: int):
115
+ """Helper function to reshape attn_mask to 3D from 2D"""
116
+ assert (
117
+ len(attn_mask.shape) == 2
118
+ ), f"Please provide attn_mask of shape (batch_size, seq_len), current shape {attn_mask.shape}"
119
+
120
+ bs, seq_len = attn_mask.shape
121
+ attn_mask = attn_mask.view(bs, 1, 1, seq_len)
122
+ attn_mask = attn_mask.expand(-1, num_attn_heads, -1, -1)
123
+ attn_mask = attn_mask.view(bs, num_attn_heads, -1, seq_len)
124
+ return attn_mask
125
+
126
+
127
+ @dataclass
128
+ class BacformerModelOutput(ModelOutput):
129
+ """Base class for outputs of the Bacformer model."""
130
+
131
+ loss: torch.FloatTensor | None = None
132
+ logits: torch.FloatTensor = None
133
+ last_hidden_state: torch.FloatTensor | None = None
134
+ attentions: Union[torch.FloatTensor, None] = None
135
+ pooler_output: torch.FloatTensor | None = None
136
+
137
+
138
+ # Taken from facebookresearch/llama/model.py
139
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
140
+ """Reshape the rotary embeddings for broadcasting."""
141
+ ndim = x.ndim
142
+ assert 0 <= 1 < ndim
143
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
144
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
145
+ return freqs_cis.view(*shape)
146
+
147
+
148
+ # Taken from facebookresearch/llama/model.py
149
+ def apply_rotary_emb(
150
+ xq: torch.Tensor,
151
+ xk: torch.Tensor,
152
+ freqs_cos: torch.Tensor,
153
+ freqs_sin: torch.Tensor,
154
+ ) -> tuple[torch.Tensor, torch.Tensor]:
155
+ """Apply rotary embeddings to the query and key tensors."""
156
+ # reshape xq and xk to match the complex representation
157
+ xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1)
158
+ xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1)
159
+
160
+ # reshape freqs_cos and freqs_sin for broadcasting
161
+ freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
162
+ freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
163
+
164
+ # apply rotation using real numbers
165
+ xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
166
+ xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
167
+ xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
168
+ xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
169
+
170
+ # flatten last two dimensions
171
+ xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
172
+ xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
173
+
174
+ return xq_out.type_as(xq), xk_out.type_as(xk)
175
+
176
+
177
+ # Taken from facebookresearch/llama/model.py
178
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
179
+ """Precompute the freqs cis for rotary embeddings."""
180
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
181
+ t = torch.arange(end, device=freqs.device) # type: ignore
182
+ freqs = torch.outer(t, freqs).float() # type: ignore
183
+
184
+ freqs_cos = torch.cos(freqs) # real part
185
+ freqs_sin = torch.sin(freqs) # imaginary part
186
+ return freqs_cos, freqs_sin
187
+
188
+
189
+ def symmetrize(x):
190
+ """Make layer symmetric in final two dimensions, used for protein-protein interaction prediction."""
191
+ return x + x.transpose(-1, -2)
192
+
193
+
194
+ def average_product_correct(x):
195
+ """Perform average product correct, used for protein-protein interaction prediction."""
196
+ a1 = x.sum(-1, keepdims=True)
197
+ a2 = x.sum(-2, keepdims=True)
198
+ a12 = x.sum((-1, -2), keepdims=True)
199
+
200
+ avg = a1 * a2
201
+ # avg.div_(a12) # in-place to reduce memory
202
+ normalized = x - avg.div_(a12)
203
+ return normalized
204
+
205
+
206
+ def scaled_dot_product_attention_w_attn_weights(
207
+ query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
208
+ ) -> tuple[torch.Tensor, torch.Tensor]:
209
+ """PyTorch Native implementation, modified to return attention weights."""
210
+ L, S = query.size(-2), key.size(-2)
211
+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
212
+ attn_bias = torch.zeros(L, S, dtype=query.dtype).to(query.device)
213
+ if is_causal:
214
+ assert attn_mask is None
215
+ temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
216
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
217
+ attn_bias.to(query.dtype)
218
+
219
+ if attn_mask is not None:
220
+ if attn_mask.dtype == torch.bool:
221
+ attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
222
+ else:
223
+ attn_bias += attn_mask
224
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
225
+ attn_weight += attn_bias
226
+ attn_weight = torch.softmax(attn_weight, dim=-1)
227
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
228
+ attn_output = attn_weight @ value
229
+ return attn_output, attn_weight
230
+
231
+
232
+ class RotarySelfAttention(nn.Module):
233
+ """Rotary self-attention module."""
234
+
235
+ def __init__(
236
+ self,
237
+ embed_dim: int,
238
+ num_heads: int,
239
+ dropout: float = 0.1,
240
+ ):
241
+ super().__init__()
242
+ self.embed_dim = embed_dim
243
+ self.num_heads = num_heads
244
+ self.dim_head = embed_dim // num_heads
245
+ self.dropout_rate = dropout
246
+
247
+ self.q = nn.Linear(embed_dim, embed_dim, bias=False)
248
+ self.k = nn.Linear(embed_dim, embed_dim, bias=False)
249
+ self.v = nn.Linear(embed_dim, embed_dim, bias=False)
250
+ self.att_proj_linear = nn.Linear(embed_dim, embed_dim)
251
+
252
+ def forward(
253
+ self,
254
+ x: torch.Tensor,
255
+ attn_mask: torch.Tensor,
256
+ freqs_cos: torch.Tensor,
257
+ freqs_sin: torch.Tensor,
258
+ is_causal: bool = False,
259
+ return_attn_weights: bool = False,
260
+ ):
261
+ """Forward pass for the rotary self-attention module."""
262
+ batch_size, seq_len, _ = x.shape
263
+ xq, xk, xv = self.q(x), self.k(x), self.v(x)
264
+ # Reshape for rotary embeddings
265
+ xq = xq.view(batch_size, seq_len, self.num_heads, self.dim_head)
266
+ xk = xk.view(batch_size, seq_len, self.num_heads, self.dim_head)
267
+ xv = xv.view(batch_size, seq_len, self.num_heads, self.dim_head)
268
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
269
+
270
+ # Reshape for attention calculation: (b_sz, n_head, s_len, d_head)
271
+ xq = xq.transpose(1, 2)
272
+ xk = xk.transpose(1, 2)
273
+ xv = xv.transpose(1, 2)
274
+
275
+ attn_weights = None
276
+ if return_attn_weights:
277
+ att, attn_weights = scaled_dot_product_attention_w_attn_weights(
278
+ query=xq,
279
+ key=xk,
280
+ value=xv,
281
+ attn_mask=attn_mask,
282
+ dropout_p=self.dropout_rate if self.training else 0.0,
283
+ is_causal=is_causal,
284
+ )
285
+ else:
286
+ att = scaled_dot_product_attention(
287
+ query=xq,
288
+ key=xk,
289
+ value=xv,
290
+ attn_mask=attn_mask,
291
+ dropout_p=self.dropout_rate if self.training else 0.0,
292
+ is_causal=is_causal,
293
+ )
294
+ # Shape (b_sz, s_len, n_head, d_head)
295
+ out = att.transpose(1, 2).contiguous()
296
+ out = out.view(batch_size, seq_len, self.num_heads * self.dim_head)
297
+
298
+ return self.att_proj_linear(out), attn_weights
299
+
300
+
301
+ class BacformerTransformerLayer(nn.Module):
302
+ """Own implementation of transformer layer which uses pytorch native MHA but returns attention weights"""
303
+
304
+ def __init__(
305
+ self,
306
+ hidden_size: int,
307
+ intermediate_size: int,
308
+ num_attention_heads: int,
309
+ dropout: float = 0.1,
310
+ activation: Literal["gelu", "relu"] = "gelu",
311
+ ):
312
+ super().__init__()
313
+ self.self_mha = RotarySelfAttention(
314
+ embed_dim=hidden_size,
315
+ num_heads=num_attention_heads,
316
+ dropout=dropout,
317
+ )
318
+
319
+ self.fc1 = nn.Linear(hidden_size, intermediate_size)
320
+ self.fc2 = nn.Linear(intermediate_size, hidden_size)
321
+ self.activation = nn.GELU() if activation == "gelu" else nn.ReLU()
322
+ self.norm1 = nn.LayerNorm(hidden_size)
323
+ self.norm2 = nn.LayerNorm(hidden_size)
324
+ self.dropout1 = nn.Dropout(dropout)
325
+ self.dropout2 = nn.Dropout(dropout)
326
+ self.dropout3 = nn.Dropout(dropout)
327
+
328
+ def forward(
329
+ self,
330
+ hidden_state: torch.Tensor,
331
+ attention_mask: torch.Tensor = None,
332
+ freqs_cos: torch.Tensor = None,
333
+ freqs_sin: torch.Tensor = None,
334
+ return_attn_weights: bool = False,
335
+ is_causal: bool = False,
336
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
337
+ """Forward pass"""
338
+ attn_outputs, attn_weights = self.self_mha(
339
+ hidden_state,
340
+ attn_mask=attention_mask,
341
+ freqs_cos=freqs_cos,
342
+ freqs_sin=freqs_sin,
343
+ return_attn_weights=return_attn_weights,
344
+ is_causal=is_causal,
345
+ )
346
+ x = self.norm1(hidden_state + self.dropout1(attn_outputs))
347
+ ff_output = self.fc2(self.dropout2(self.activation(self.fc1(x))))
348
+ x = self.norm2(x + self.dropout3(ff_output))
349
+ return x, attn_weights
350
+
351
+
352
+ class BacformerTransformerEncoder(nn.Module):
353
+ """Own implementation of Transformer which return attention weights"""
354
+
355
+ def __init__(
356
+ self,
357
+ num_hidden_layers: int,
358
+ hidden_size: int,
359
+ intermediate_size: int,
360
+ num_attention_heads: int,
361
+ dropout: float = 0.1,
362
+ activation: Literal["gelu", "relu"] = "gelu",
363
+ ):
364
+ super().__init__()
365
+
366
+ self.layers = nn.ModuleList(
367
+ [
368
+ BacformerTransformerLayer(
369
+ hidden_size=hidden_size,
370
+ intermediate_size=intermediate_size,
371
+ num_attention_heads=num_attention_heads,
372
+ dropout=dropout,
373
+ activation=activation,
374
+ )
375
+ for _ in range(num_hidden_layers)
376
+ ]
377
+ )
378
+ self.gradient_checkpointing = False
379
+
380
+ def forward(
381
+ self,
382
+ hidden_state: torch.Tensor,
383
+ attention_mask: torch.Tensor = None,
384
+ freqs_cos: torch.Tensor = None,
385
+ freqs_sin: torch.Tensor = None,
386
+ return_attn_weights: bool = False,
387
+ is_causal: bool = False,
388
+ ) -> tuple[torch.Tensor, list[torch.Tensor | None]]:
389
+ """Forward pass"""
390
+ attn_weights_arr = []
391
+ for layer in self.layers:
392
+ if self.gradient_checkpointing and self.training:
393
+ hidden_state, attn_weights = self._gradient_checkpointing_func(
394
+ layer.__call__,
395
+ hidden_state,
396
+ attention_mask,
397
+ freqs_cos,
398
+ freqs_sin,
399
+ return_attn_weights,
400
+ is_causal,
401
+ )
402
+ else:
403
+ hidden_state, attn_weights = layer(
404
+ hidden_state=hidden_state,
405
+ attention_mask=attention_mask,
406
+ freqs_cos=freqs_cos,
407
+ freqs_sin=freqs_sin,
408
+ return_attn_weights=return_attn_weights,
409
+ is_causal=is_causal,
410
+ )
411
+ # keep the attention weights from each layer
412
+ attn_weights_arr.append(attn_weights)
413
+ return hidden_state, attn_weights_arr
414
+
415
+
416
+ class BacformerEmbeddings(nn.Module):
417
+ """Construct the protein embeddings from protein sequence, position embeddings and sequence type embeddings."""
418
+
419
+ def __init__(self, config: BacformerConfig):
420
+ super().__init__()
421
+ self.config = config
422
+ self.linear = nn.Linear(config.hidden_size, config.hidden_size)
423
+
424
+ self.token_type_embeddings = nn.Embedding(
425
+ num_embeddings=config.max_token_type_embeddings + 1,
426
+ embedding_dim=config.hidden_size,
427
+ padding_idx=config.max_token_type_embeddings,
428
+ )
429
+
430
+ self.special_tokens_embeddings = nn.Embedding(
431
+ num_embeddings=config.num_special_tokens,
432
+ embedding_dim=config.hidden_size,
433
+ )
434
+ self.prot_emb_token_id = config.prot_emb_token_id
435
+ self.pad_token_id = config.pad_token_id
436
+
437
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
438
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
439
+
440
+ def forward(
441
+ self,
442
+ protein_embeddings: torch.Tensor = None,
443
+ special_tokens_mask: torch.Tensor = None,
444
+ token_type_ids: torch.Tensor = None,
445
+ labels: torch.Tensor = None, # used for causal protein family modeling
446
+ property_ids: torch.Tensor = None, # used for conditional fine-tuning for desired property
447
+ ) -> torch.Tensor:
448
+ """Forward pass for protein embeddings."""
449
+ bs, seq_length, dim = protein_embeddings.shape
450
+
451
+ # pass the pooled ESM protein embeddings through a linear layer
452
+ protein_embeddings = self.linear(protein_embeddings)
453
+ protein_embeddings = torch.where(
454
+ special_tokens_mask.unsqueeze(-1).repeat(1, 1, dim) == self.prot_emb_token_id,
455
+ protein_embeddings,
456
+ self.special_tokens_embeddings(special_tokens_mask),
457
+ )
458
+
459
+ if token_type_ids is not None:
460
+ protein_embeddings += self.token_type_embeddings(token_type_ids)
461
+
462
+ protein_embeddings = self.LayerNorm(protein_embeddings)
463
+ protein_embeddings = self.dropout(protein_embeddings)
464
+ return protein_embeddings
465
+
466
+
467
+ class BacformerProteinFamilyEmbeddings(nn.Module):
468
+ """Construct the protein embeddings from protein family tokens, special tokens and sequence type embeddings."""
469
+
470
+ def __init__(
471
+ self,
472
+ config: BacformerConfig,
473
+ protein_family_embeddings: torch.Tensor = None,
474
+ token_type_embeddings: torch.Tensor = None,
475
+ special_tokens_embeddings: torch.Tensor = None,
476
+ n_conditional_properties: int = None,
477
+ ):
478
+ super().__init__()
479
+ self.config = config
480
+
481
+ if protein_family_embeddings is not None:
482
+ self.protein_family_embeddings = nn.Embedding.from_pretrained(
483
+ protein_family_embeddings,
484
+ freeze=False,
485
+ padding_idx=config.pad_token_id,
486
+ )
487
+ else:
488
+ self.protein_family_embeddings = nn.Embedding(
489
+ num_embeddings=config.protein_clusters_vocab_size + 1,
490
+ embedding_dim=config.hidden_size,
491
+ padding_idx=config.pad_token_id,
492
+ )
493
+
494
+ if token_type_embeddings is not None:
495
+ self.token_type_embeddings = nn.Embedding.from_pretrained(
496
+ token_type_embeddings,
497
+ freeze=False,
498
+ padding_idx=config.max_token_type_embeddings,
499
+ )
500
+ else:
501
+ self.token_type_embeddings = nn.Embedding(
502
+ num_embeddings=config.max_token_type_embeddings + 1,
503
+ embedding_dim=config.hidden_size,
504
+ padding_idx=config.max_token_type_embeddings,
505
+ )
506
+
507
+ if special_tokens_embeddings is not None:
508
+ self.special_tokens_embeddings = nn.Embedding.from_pretrained(
509
+ special_tokens_embeddings,
510
+ freeze=False,
511
+ padding_idx=config.pad_token_id,
512
+ )
513
+ else:
514
+ self.special_tokens_embeddings = nn.Embedding(
515
+ num_embeddings=config.num_special_tokens,
516
+ embedding_dim=config.hidden_size,
517
+ padding_idx=config.pad_token_id,
518
+ )
519
+
520
+ # add layer for conditional properties
521
+ if n_conditional_properties is not None:
522
+ self.conditional_properties_layer = nn.Embedding(n_conditional_properties, config.hidden_size)
523
+
524
+ self.prot_emb_token_id = config.prot_emb_token_id
525
+ self.pad_token_id = config.pad_token_id
526
+
527
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
528
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
529
+
530
+ def forward(
531
+ self,
532
+ protein_embeddings: torch.Tensor = None,
533
+ special_tokens_mask: torch.Tensor = None,
534
+ token_type_ids: torch.Tensor = None,
535
+ labels: torch.Tensor = None, # used for causal protein family modeling
536
+ property_ids: torch.Tensor = None, # used for conditional fine-tuning for desired property
537
+ ) -> torch.Tensor:
538
+ """Forward pass for protein embeddings."""
539
+ # pass the pooled ESM protein embeddings through a linear layer
540
+ # replace -100 with pad_token_id
541
+ labels[labels == -100] = self.pad_token_id
542
+ protein_embeddings = self.protein_family_embeddings(labels)
543
+
544
+ bs, seq_length, dim = protein_embeddings.shape
545
+ protein_embeddings = torch.where(
546
+ special_tokens_mask.unsqueeze(-1).repeat(1, 1, dim) == self.prot_emb_token_id,
547
+ protein_embeddings,
548
+ self.special_tokens_embeddings(special_tokens_mask),
549
+ )
550
+
551
+ if token_type_ids is not None:
552
+ protein_embeddings += self.token_type_embeddings(token_type_ids)
553
+
554
+ if property_ids is not None:
555
+ # get the embeddings for the conditional properties
556
+ property_embedding = self.conditional_properties_layer(property_ids).unsqueeze(1)
557
+ # concatenate the protein embeddings with the conditional properties embeddings
558
+ # property embeddings are added to the beginning of the protein embeddings after the CLS token
559
+ protein_embeddings = torch.cat(
560
+ [
561
+ protein_embeddings[:, :1, :], # CLS token
562
+ property_embedding, # conditional properties embeddings
563
+ protein_embeddings[:, 1:, :],
564
+ ], # protein embeddings
565
+ dim=1,
566
+ )
567
+
568
+ protein_embeddings = self.LayerNorm(protein_embeddings)
569
+ protein_embeddings = self.dropout(protein_embeddings)
570
+ return protein_embeddings
571
+
572
+
573
+ class BacformerEncoder(nn.Module):
574
+ """Bacformer encoder model"""
575
+
576
+ def __init__(self, config: BacformerConfig):
577
+ super().__init__()
578
+ self.config = config
579
+
580
+ self.encoder = BacformerTransformerEncoder(
581
+ num_hidden_layers=config.num_hidden_layers,
582
+ hidden_size=config.hidden_size,
583
+ num_attention_heads=config.num_attention_heads,
584
+ intermediate_size=config.intermediate_size,
585
+ activation="gelu",
586
+ dropout=config.attention_probs_dropout_prob,
587
+ )
588
+
589
+ # Note that config.max_position_embeddings is multiplied by 1.5 because the token limit for the Bacformer of
590
+ # models is 6000. Adding this multiplier instead of using 6000 directly allows for dynamism of token
591
+ # lengths while training or fine-tuning.
592
+ freqs_cos, freqs_sin = precompute_freqs_cis(
593
+ config.hidden_size // config.num_attention_heads, int(config.max_position_embeddings * 1.5)
594
+ )
595
+ self.register_buffer("freqs_cos", freqs_cos, persistent=False)
596
+ self.register_buffer("freqs_sin", freqs_sin, persistent=False)
597
+
598
+ def forward(
599
+ self,
600
+ hidden_states: torch.Tensor,
601
+ attention_mask: torch.Tensor = None,
602
+ return_attn_weights: Union[bool, None] = None,
603
+ is_causal: bool = False,
604
+ ) -> tuple[torch.Tensor, list[torch.Tensor | None]]:
605
+ """Pass the input through the encoder layers in turn.
606
+
607
+ Args:
608
+ hidden_states: hidden states from the BacformerEmbeddings layer
609
+ attention_mask: mask for the attention in the transformer
610
+ """
611
+ return_attn_weights = (
612
+ return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
613
+ )
614
+ bs, seq_len, _ = hidden_states.shape
615
+ last_hidden_state, attn_weights = self.encoder(
616
+ hidden_state=hidden_states,
617
+ attention_mask=attention_mask,
618
+ freqs_cos=self.freqs_cos[:seq_len, :],
619
+ freqs_sin=self.freqs_sin[:seq_len, :],
620
+ return_attn_weights=return_attn_weights,
621
+ is_causal=is_causal,
622
+ )
623
+ return last_hidden_state, attn_weights
624
+
625
+
626
+ class BacformerPreTrainedModel(PreTrainedModel):
627
+ """An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models."""
628
+
629
+ config_class = BacformerConfig
630
+ base_model_prefix = "bacformer"
631
+ supports_gradient_checkpointing = True
632
+ _no_split_modules = ["BacformerEmbeddings", "BacformerTransformerLayer"]
633
+
634
+ # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
635
+ def _init_weights(self, module):
636
+ """Initialize the weights"""
637
+ if isinstance(module, nn.Linear):
638
+ # Slightly different from the TF version which uses truncated_normal for initialization
639
+ # cf https://github.com/pytorch/pytorch/pull/5617
640
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
641
+ if module.bias is not None:
642
+ module.bias.data.zero_()
643
+ elif isinstance(module, nn.Embedding):
644
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
645
+ if module.padding_idx is not None:
646
+ module.weight.data[module.padding_idx].zero_()
647
+ elif isinstance(module, nn.LayerNorm):
648
+ module.bias.data.zero_()
649
+ module.weight.data.fill_(1.0)
650
+
651
+
652
+ class BacformerModel(BacformerPreTrainedModel):
653
+ """Bacformer model."""
654
+
655
+ def __init__(self, config: BacformerConfig, add_pooling_layer: bool = False):
656
+ super().__init__(config)
657
+ self.config = config
658
+
659
+ self.embeddings = BacformerEmbeddings(config)
660
+ self.encoder = BacformerEncoder(config)
661
+
662
+ self.pooler = BacformerPooler(config) if add_pooling_layer else None
663
+
664
+ # Initialize weights and apply final processing
665
+ self.post_init()
666
+
667
+ def forward(
668
+ self,
669
+ protein_embeddings: torch.Tensor = None,
670
+ special_tokens_mask: torch.Tensor = None,
671
+ token_type_ids: torch.Tensor = None,
672
+ attention_mask: torch.Tensor = None,
673
+ labels: torch.Tensor = None,
674
+ property_ids: torch.Tensor = None,
675
+ return_attn_weights: bool = False,
676
+ return_dict: Union[bool, None] = None,
677
+ is_causal: bool = False,
678
+ ) -> Optional[BacformerModelOutput]:
679
+ """Forward method for the model."""
680
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
681
+ # get embeddings
682
+ protein_embeddings = self.embeddings(
683
+ protein_embeddings=protein_embeddings,
684
+ labels=labels,
685
+ special_tokens_mask=special_tokens_mask,
686
+ token_type_ids=token_type_ids,
687
+ property_ids=property_ids,
688
+ )
689
+
690
+ # create 3D attention mask from 2D if not doing causal GM
691
+ if attention_mask is not None and not is_causal:
692
+ attention_mask = create_4d_from_2d_attn_mask(
693
+ attn_mask=attention_mask, num_attn_heads=self.config.num_attention_heads
694
+ ).bool()
695
+
696
+ last_hidden_state, attentions = self.encoder(
697
+ hidden_states=protein_embeddings,
698
+ attention_mask=attention_mask,
699
+ return_attn_weights=return_attn_weights,
700
+ is_causal=is_causal,
701
+ )
702
+ pooler_output = (
703
+ self.pooler(hidden_states=last_hidden_state, padding_mask=attention_mask)
704
+ if self.pooler is not None
705
+ else None
706
+ )
707
+
708
+ if not return_dict:
709
+ return (last_hidden_state, pooler_output, attentions)
710
+
711
+ return BacformerModelOutput(
712
+ last_hidden_state=last_hidden_state,
713
+ pooler_output=pooler_output,
714
+ attentions=attentions,
715
+ )
716
+
717
+
718
+ class BacformerForCausalGM(BacformerPreTrainedModel):
719
+ """Bacformer model with genomic modeling head on top"""
720
+
721
+ _tied_weights_keys = ["gm_head.decoder.weight"]
722
+
723
+ def __init__(self, config: BacformerConfig):
724
+ super().__init__(config)
725
+ self.config = config
726
+
727
+ self.bacformer = BacformerModel(config, add_pooling_layer=False)
728
+ self.gm_head = BacformerGMHead(config)
729
+
730
+ # Initialize weights
731
+ self.init_weights()
732
+
733
+ def forward(
734
+ self,
735
+ protein_embeddings: torch.Tensor,
736
+ special_tokens_mask: torch.Tensor,
737
+ labels: torch.Tensor = None,
738
+ token_type_ids: torch.Tensor = None,
739
+ attention_mask: torch.Tensor = None,
740
+ return_attn_weights: bool = None,
741
+ return_dict: Union[bool, None] = None,
742
+ ) -> Optional[BacformerModelOutput]:
743
+ """Forward method for the model."""
744
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
745
+ return_attn_weights = (
746
+ return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
747
+ )
748
+
749
+ outputs = self.bacformer(
750
+ protein_embeddings=protein_embeddings,
751
+ special_tokens_mask=special_tokens_mask,
752
+ token_type_ids=token_type_ids,
753
+ attention_mask=None, # attention mechanism handles the causal mask
754
+ return_attn_weights=return_attn_weights,
755
+ return_dict=return_dict,
756
+ is_causal=True,
757
+ )
758
+ last_hidden_state = outputs[0]
759
+ prediction_scores = self.gm_head(last_hidden_state)
760
+
761
+ loss = None
762
+ if labels is not None:
763
+ labels = labels.to(prediction_scores.device)
764
+
765
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous().view(-1, prediction_scores.shape[-1])
766
+ labels = labels[:, 1:].contiguous().view(-1)
767
+ loss = cross_entropy(shifted_prediction_scores, labels)
768
+
769
+ if not return_dict:
770
+ return (
771
+ loss,
772
+ prediction_scores,
773
+ ) + outputs
774
+
775
+ return BacformerModelOutput(
776
+ loss=loss,
777
+ logits=prediction_scores,
778
+ last_hidden_state=outputs.last_hidden_state,
779
+ attentions=outputs.attentions,
780
+ )
781
+
782
+
783
+ class BacformerForMaskedGM(BacformerPreTrainedModel):
784
+ """Bacformer model with genomic modeling head on top"""
785
+
786
+ _tied_weights_keys = ["gm_head.decoder.weight"]
787
+
788
+ def __init__(self, config: BacformerConfig):
789
+ super().__init__(config)
790
+ self.config = config
791
+
792
+ self.bacformer = BacformerModel(config, add_pooling_layer=False)
793
+ self.gm_head = BacformerGMHead(config)
794
+
795
+ # Initialize weights
796
+ self.init_weights()
797
+
798
+ def forward(
799
+ self,
800
+ protein_embeddings: torch.Tensor,
801
+ special_tokens_mask: torch.Tensor,
802
+ labels: torch.Tensor = None,
803
+ token_type_ids: torch.Tensor = None,
804
+ attention_mask: torch.Tensor = None,
805
+ return_attn_weights: bool = None,
806
+ return_dict: Union[bool, None] = None,
807
+ ) -> Union[BacformerModelOutput, None]:
808
+ """Forward method for the model."""
809
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
810
+ return_attn_weights = (
811
+ return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
812
+ )
813
+
814
+ outputs = self.bacformer(
815
+ protein_embeddings=protein_embeddings,
816
+ special_tokens_mask=special_tokens_mask,
817
+ token_type_ids=token_type_ids,
818
+ attention_mask=attention_mask,
819
+ return_attn_weights=return_attn_weights,
820
+ return_dict=return_dict,
821
+ )
822
+ last_hidden_state = outputs[0]
823
+
824
+ # to speed up the forward pass, let's only consider the masked tokens
825
+
826
+ loss = None
827
+ if labels is not None:
828
+ # to speed up the forward pass, let's only consider the masked tokens
829
+ last_hidden_state = last_hidden_state[labels != -100]
830
+ prediction_scores = self.gm_head(last_hidden_state)
831
+ labels = labels.to(prediction_scores.device)
832
+
833
+ ### notes
834
+ # use the labels to get -100 for non-masked tokens
835
+ # do not use special_tokens_mask
836
+ # check how the labels are constructed
837
+
838
+ # only considering the masked tokens
839
+ labels = labels[labels != -100]
840
+ loss = cross_entropy(prediction_scores, labels)
841
+ else:
842
+ prediction_scores = self.gm_head(last_hidden_state)
843
+
844
+ if not return_dict:
845
+ return (
846
+ loss,
847
+ prediction_scores,
848
+ ) + outputs
849
+
850
+ return BacformerModelOutput(
851
+ loss=loss,
852
+ logits=prediction_scores,
853
+ last_hidden_state=outputs.last_hidden_state,
854
+ attentions=outputs.attentions,
855
+ )
856
+
857
+
858
+ class BacformerForCausalProteinFamilyModeling(BacformerPreTrainedModel):
859
+ """Bacformer model for causal modeling of protein families. Using protein family as tokens rather than protein embeddings"""
860
+
861
+ _tied_weights_keys = ["gm_head.decoder.weight"]
862
+
863
+ def __init__(
864
+ self,
865
+ config: BacformerConfig,
866
+ n_conditional_properties: int = None,
867
+ initialise_from_non_pfm_model: bool = False,
868
+ ):
869
+ super().__init__(config)
870
+ self.config = config
871
+ self.cls_token_id = SPECIAL_TOKENS_DICT["CLS"]
872
+
873
+ self.bacformer = BacformerModel(config, add_pooling_layer=False)
874
+ self.gm_head = BacformerGMHead(config)
875
+
876
+ if initialise_from_non_pfm_model:
877
+ # Initialize weights
878
+ self.init_weights()
879
+ # overwrite the embeddings with the pretrained
880
+ # protein family embeddings from the decoder of the GM Head
881
+ self.bacformer.embeddings = BacformerProteinFamilyEmbeddings(
882
+ config,
883
+ protein_family_embeddings=self.gm_head.decoder.weight,
884
+ token_type_embeddings=self.bacformer.embeddings.token_type_embeddings.weight,
885
+ special_tokens_embeddings=self.bacformer.embeddings.special_tokens_embeddings.weight,
886
+ n_conditional_properties=n_conditional_properties,
887
+ )
888
+ else:
889
+ self.bacformer.embeddings = BacformerProteinFamilyEmbeddings(
890
+ config,
891
+ n_conditional_properties=n_conditional_properties,
892
+ )
893
+ self.init_weights()
894
+
895
+ def forward(
896
+ self,
897
+ labels: torch.Tensor = None,
898
+ special_tokens_mask: torch.Tensor = None,
899
+ token_type_ids: torch.Tensor = None,
900
+ property_ids: torch.Tensor = None,
901
+ return_attn_weights: bool = None,
902
+ return_dict: Union[bool, None] = None,
903
+ ) -> Optional[BacformerModelOutput]:
904
+ """Forward method for the model."""
905
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
906
+ return_attn_weights = (
907
+ return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
908
+ )
909
+
910
+ outputs = self.bacformer(
911
+ protein_embeddings=None,
912
+ labels=labels,
913
+ special_tokens_mask=special_tokens_mask,
914
+ token_type_ids=token_type_ids,
915
+ property_ids=property_ids,
916
+ return_attn_weights=return_attn_weights,
917
+ return_dict=return_dict,
918
+ is_causal=True,
919
+ )
920
+ last_hidden_state = outputs[0]
921
+ prediction_scores = self.gm_head(last_hidden_state)
922
+
923
+ loss = None
924
+ if labels is not None:
925
+ if property_ids is not None:
926
+ labels = torch.cat(
927
+ [
928
+ torch.tensor([-100], dtype=torch.long)
929
+ .unsqueeze(0)
930
+ .to(labels.device), # account for the property token
931
+ labels,
932
+ ],
933
+ dim=1,
934
+ ) # ignore index
935
+ labels = labels.to(prediction_scores.device)
936
+
937
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous().view(-1, prediction_scores.shape[-1])
938
+ labels = labels[:, 1:].contiguous().view(-1)
939
+ loss = cross_entropy(shifted_prediction_scores, labels)
940
+
941
+ if not return_dict:
942
+ return (
943
+ loss,
944
+ prediction_scores,
945
+ ) + outputs
946
+
947
+ return BacformerModelOutput(
948
+ loss=loss,
949
+ logits=prediction_scores,
950
+ last_hidden_state=outputs.last_hidden_state,
951
+ attentions=outputs.attentions,
952
+ )
953
+
954
+ def generate(
955
+ self,
956
+ protein_family_ids: torch.LongTensor,
957
+ special_tokens_mask: torch.LongTensor = None,
958
+ token_type_ids: torch.LongTensor = None,
959
+ max_length: int = 6000,
960
+ end_token_id: int = 50000,
961
+ do_sample: bool = False,
962
+ top_k: int = 50,
963
+ top_p: float = 1.0,
964
+ temperature: float = 1.0,
965
+ property_ids: torch.LongTensor = None,
966
+ return_last_hidden_states: bool = False,
967
+ ):
968
+ """
969
+ Generate a sequence of tokens autoregressively from a given prompt.
970
+
971
+ Args:
972
+ protein_family_ids (torch.LongTensor): Tensor of shape (batch, seq_len) with token indices.
973
+ max_length (int): Maximum length of the generated sequence (prompt + newly generated).
974
+ end_token_id (int, optional): Token ID signifying end-of-sequence (END).
975
+ If encountered, generation stops.
976
+ do_sample (bool): Whether to sample from the probability distribution (True)
977
+ or use greedy decoding (False).
978
+ top_k (int): If >0, use top-k filtering in sampling mode.
979
+ top_p (float): If <1.0, use nucleus (top-p) filtering in sampling mode.
980
+ temperature (float): Softmax temperature for scaling logits.
981
+ Higher => more random, lower => more deterministic.
982
+ return_last_hidden_states (bool): If True, return final hidden states as well.
983
+
984
+ Returns
985
+ -------
986
+ torch.LongTensor: The generated token sequence of shape (batch, final_seq_len).
987
+ (Optional) torch.FloatTensor: Final hidden states of shape (batch, final_seq_len, hidden_dim)
988
+ if `return_hidden_states=True`.
989
+ """
990
+ # Default END token
991
+ if end_token_id is None:
992
+ end_token_id = getattr(self, "end_token_id", None)
993
+
994
+ # Switch to eval mode and move input to correct device
995
+ self.eval()
996
+ device = next(self.parameters()).device
997
+ protein_family_ids = protein_family_ids.to(device)
998
+
999
+ # create a special tokens mask if not provided
1000
+ if special_tokens_mask is None:
1001
+ # add a cls token at the beginning
1002
+ protein_family_ids = torch.cat(
1003
+ [torch.tensor([[-100]]).to(device), protein_family_ids],
1004
+ dim=1,
1005
+ )
1006
+ special_tokens_mask = [self.cls_token_id] + [self.config.prot_emb_token_id] * (
1007
+ protein_family_ids.shape[1] - 1
1008
+ )
1009
+ special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.long).to(device)
1010
+
1011
+ # create a token type mask if not provided
1012
+ if token_type_ids is None:
1013
+ token_type_ids = torch.zeros_like(protein_family_ids)
1014
+
1015
+ # Prepare the initial sequence and define max new tokens
1016
+ generated = protein_family_ids.clone()
1017
+ batch_size, prompt_length = generated.shape
1018
+ max_new_tokens = max_length - prompt_length
1019
+ if max_new_tokens <= 0:
1020
+ max_new_tokens = 0
1021
+
1022
+ # Disable gradient calculations for generation
1023
+ with torch.no_grad():
1024
+ for _step in range(max_new_tokens):
1025
+ # Forward pass
1026
+ logits = self.forward(
1027
+ labels=generated,
1028
+ special_tokens_mask=special_tokens_mask,
1029
+ # assume it's all on one chromosome
1030
+ token_type_ids=token_type_ids,
1031
+ property_ids=property_ids,
1032
+ return_dict=True,
1033
+ ).logits
1034
+ # Focus on the last token's logits
1035
+ next_token_logits = logits[:, -1, :] # (batch_size, vocab_size)
1036
+
1037
+ # Apply temperature
1038
+ if temperature != 1.0:
1039
+ next_token_logits = next_token_logits / temperature
1040
+
1041
+ # Sampling or greedy?
1042
+ if do_sample:
1043
+ # Top-k filter
1044
+ next_token_logits = top_k_filtering(next_token_logits, top_k=top_k)
1045
+ # Top-p filter
1046
+ next_token_logits = top_p_filtering(next_token_logits, top_p=top_p)
1047
+
1048
+ probs = softmax(next_token_logits, dim=-1)
1049
+ next_token_id = torch.multinomial(probs, num_samples=1)
1050
+ else:
1051
+ # Greedy decoding
1052
+ next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
1053
+
1054
+ # Append predicted token
1055
+ generated = torch.cat([generated, next_token_id], dim=1)
1056
+ special_tokens_mask = torch.cat(
1057
+ [special_tokens_mask, torch.tensor([[self.config.prot_emb_token_id]]).to(generated.device)], dim=1
1058
+ )
1059
+ last_token_type_id = token_type_ids[:, -1].unsqueeze(1)
1060
+ token_type_ids = torch.cat([token_type_ids, last_token_type_id], dim=1)
1061
+
1062
+ # Check for END in all sequences
1063
+ if end_token_id is not None:
1064
+ if (next_token_id.squeeze(1) == end_token_id).all():
1065
+ # If every sequence ended, break early
1066
+ break
1067
+
1068
+ if not return_last_hidden_states:
1069
+ return generated
1070
+
1071
+ # Optionally compute final hidden states
1072
+ if return_last_hidden_states:
1073
+ last_hidden_state = self.forward(
1074
+ labels=generated,
1075
+ special_tokens_mask=special_tokens_mask,
1076
+ token_type_ids=token_type_ids,
1077
+ return_dict=True,
1078
+ ).last_hidden_state
1079
+
1080
+ return generated, last_hidden_state
1081
+
1082
+
1083
+ class BacformerForMaskedGMWithContrastiveLoss(BacformerPreTrainedModel):
1084
+ """Bacformer model with genomic modeling head on top"""
1085
+
1086
+ _tied_weights_keys = ["gm_head.decoder.weight"]
1087
+
1088
+ def __init__(self, config: BacformerConfig):
1089
+ super().__init__(config)
1090
+ self.config = config
1091
+
1092
+ self.bacformer = BacformerModel(config, add_pooling_layer=False)
1093
+ self.gm_head = BacformerGMHead(config)
1094
+
1095
+ # Initialize weights
1096
+ self.init_weights()
1097
+
1098
+ def forward(
1099
+ self,
1100
+ protein_embeddings: torch.Tensor,
1101
+ special_tokens_mask: torch.Tensor,
1102
+ labels: torch.Tensor = None,
1103
+ token_type_ids: torch.Tensor = None,
1104
+ attention_mask: torch.Tensor = None,
1105
+ return_attn_weights: bool = None,
1106
+ return_dict: Union[bool, None] = None,
1107
+ ) -> Union[BacformerModelOutput, None]:
1108
+ """Forward method for the model."""
1109
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1110
+ return_attn_weights = (
1111
+ return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
1112
+ )
1113
+
1114
+ outputs = self.bacformer(
1115
+ protein_embeddings=protein_embeddings,
1116
+ special_tokens_mask=special_tokens_mask,
1117
+ token_type_ids=token_type_ids,
1118
+ attention_mask=attention_mask,
1119
+ return_attn_weights=return_attn_weights,
1120
+ return_dict=return_dict,
1121
+ )
1122
+ last_hidden_state = outputs[0]
1123
+
1124
+ # to speed up the forward pass, let's only consider the masked tokens
1125
+
1126
+ loss = None
1127
+ if labels is not None:
1128
+ # contrastive loss
1129
+ contrastive_loss = compute_contrastive_loss(protein_embeddings, last_hidden_state, special_tokens_mask)
1130
+ # to speed up the forward pass, let's only consider the masked tokens
1131
+ last_hidden_state = last_hidden_state[labels != -100]
1132
+ prediction_scores = self.gm_head(last_hidden_state)
1133
+ labels = labels.to(prediction_scores.device)
1134
+
1135
+ # only considering the masked tokens
1136
+ labels = labels[labels != -100]
1137
+ masked_loss = cross_entropy(prediction_scores, labels)
1138
+ loss = masked_loss + self.config.alpha_contrastive_loss * contrastive_loss
1139
+ else:
1140
+ prediction_scores = self.gm_head(last_hidden_state)
1141
+
1142
+ if not return_dict:
1143
+ return (
1144
+ loss,
1145
+ prediction_scores,
1146
+ ) + outputs
1147
+
1148
+ return BacformerModelOutput(
1149
+ loss=loss,
1150
+ logits=prediction_scores,
1151
+ last_hidden_state=outputs.last_hidden_state,
1152
+ attentions=outputs.attentions,
1153
+ )
1154
+
1155
+
1156
+ class BacformerForProteinClassification(BacformerPreTrainedModel):
+     """Bacformer model with a classification head on top for protein classification tasks."""
+
+     def __init__(self, config: BacformerConfig, benchmark_esm: bool = False):
+         super().__init__(config)
+         self.config = config
+         self.benchmark_esm = benchmark_esm
+
+         self.bacformer = BacformerModel(config, add_pooling_layer=False)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         protein_embeddings: torch.Tensor,
+         special_tokens_mask: torch.Tensor,
+         labels: torch.Tensor = None,
+         token_type_ids: torch.Tensor = None,
+         attention_mask: torch.Tensor = None,
+         return_attn_weights: bool = None,
+         return_dict: Union[bool, None] = None,
+     ) -> Optional[BacformerModelOutput]:
+         """Forward method for the model."""
+         return_dict = return_dict if return_dict is not None else self.config.return_dict
+         return_attn_weights = (
+             return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
+         )
+
+         if self.benchmark_esm:
+             outputs = [protein_embeddings]
+         else:
+             outputs = self.bacformer(
+                 protein_embeddings=protein_embeddings,
+                 special_tokens_mask=special_tokens_mask,
+                 token_type_ids=token_type_ids,
+                 attention_mask=attention_mask,
+                 return_attn_weights=return_attn_weights,
+                 return_dict=return_dict,
+             )
+
+         last_hidden_state = outputs[0]
+
+         last_hidden_state = self.dropout(last_hidden_state)
+         logits = self.classifier(last_hidden_state)
+
+         loss = None
+         if labels is not None:
+             labels = labels.to(logits.device)
+
+             if self.config.problem_type == "regression":
+                 loss = mse_loss(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss = cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
+             elif (
+                 self.config.problem_type == "multi_label_classification"
+                 or self.config.problem_type == "binary_classification"
+             ):
+                 # remove the -100 labels from loss computation
+                 mask = torch.ones_like(labels.view(-1)) - (labels.view(-1) == -100.0).float()
+                 loss = binary_cross_entropy_with_logits(
+                     logits.view(-1), labels.view(-1).type_as(logits), reduction="none"
+                 )
+                 loss = (loss * mask).sum() / mask.sum()
+
+         if not return_dict:
+             return (
+                 loss,
+                 None,
+                 logits,
+             )  # + outputs
+
+         return BacformerModelOutput(
+             loss=loss,
+             logits=logits,
+             last_hidden_state=last_hidden_state,
+             attentions=outputs.attentions,
+         )
+
+
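# --- Editor's sketch (not part of modeling_bacformer.py) ---
# Minimal reproduction of the binary/multi-label branch above: positions
# labelled -100 carry no annotation and are excluded from the BCE-with-logits
# loss via a 0/1 mask. The logits and labels below are made-up values.
import torch
from torch.nn.functional import binary_cross_entropy_with_logits

logits = torch.tensor([0.2, -1.3, 0.7, 2.1])
labels = torch.tensor([1.0, -100.0, 0.0, 1.0])    # -100 = "no label" for this position
mask = (labels.view(-1) != -100.0).float()         # 1 where a label exists
per_elem = binary_cross_entropy_with_logits(
    logits.view(-1), labels.view(-1).type_as(logits), reduction="none"
)
loss = (per_elem * mask).sum() / mask.sum()        # mean over labelled positions only
print(float(loss))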
+ class BacformerForGenomeClassification(BacformerPreTrainedModel):
+     """Bacformer model with a classification head on top for genome classification tasks."""
+
+     def __init__(self, config: BacformerConfig):
+         super().__init__(config)
+         self.config = config
+
+         self.bacformer = BacformerModel(config, add_pooling_layer=False)
+         self.classifier = BacformerGenomeClassificationHead(config)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         protein_embeddings: torch.Tensor,
+         special_tokens_mask: torch.Tensor,
+         labels: torch.Tensor = None,
+         token_type_ids: torch.Tensor = None,
+         attention_mask: torch.Tensor = None,
+         return_attn_weights: bool = None,
+         return_dict: Union[bool, None] = None,
+     ) -> Optional[BacformerModelOutput]:
+         """Forward method for the model."""
+         return_dict = return_dict if return_dict is not None else self.config.return_dict
+         return_attn_weights = (
+             return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
+         )
+
+         outputs = self.bacformer(
+             protein_embeddings=protein_embeddings,
+             special_tokens_mask=special_tokens_mask,
+             token_type_ids=token_type_ids,
+             attention_mask=attention_mask,
+             return_attn_weights=return_attn_weights,
+             return_dict=return_dict,
+         )
+         last_hidden_state = outputs[0]
+         logits = self.classifier(last_hidden_state, attention_mask)
+
+         loss = None
+         if labels is not None:
+             labels = labels.to(logits.device)
+
+             if self.config.problem_type == "regression":
+                 loss = mse_loss(logits.view(-1), labels.view(-1))
+             elif self.config.problem_type == "binary_classification":
+                 loss = binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
+             elif self.config.problem_type == "single_label_classification":
+                 loss = cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss = binary_cross_entropy_with_logits(logits, labels)
+
+         if not return_dict:
+             return (
+                 loss,
+                 None,
+                 logits,
+             )
+
+         return BacformerModelOutput(
+             loss=loss,
+             logits=logits,
+             last_hidden_state=outputs.last_hidden_state,
+             attentions=outputs.attentions,
+         )
+
+
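# --- Editor's sketch (not part of modeling_bacformer.py) ---
# Shape walk-through of the genome classifier above, under toy assumptions:
# per-protein hidden states (batch, n_proteins, hidden) are mean-pooled over
# non-padded positions (the einsum used in BacformerGenomeClassificationHead),
# projected to num_labels, and scored with cross-entropy. All sizes are made up
# except hidden=480, which matches config.json.
import torch
from torch import nn
from torch.nn.functional import cross_entropy

batch, n_proteins, hidden, num_labels = 2, 5, 480, 3
hidden_states = torch.randn(batch, n_proteins, hidden)
padding_mask = torch.tensor([[1., 1., 1., 0., 0.],
                             [1., 1., 1., 1., 1.]])    # 1 = real protein, 0 = padding
pooled = torch.einsum("ijk,ij->ik", hidden_states, padding_mask) / padding_mask.sum(1, keepdim=True)
out_proj = nn.Linear(hidden, num_labels)               # stand-in for the classification head
logits = out_proj(pooled)                              # (batch, num_labels)
labels = torch.tensor([0, 2])
loss = cross_entropy(logits.view(-1, num_labels), labels.view(-1))
print(logits.shape, float(loss))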
+ class BacformerForProteinProteinInteraction(BacformerPreTrainedModel):
+     """Bacformer model with a protein-protein interaction head on top."""
+
+     def __init__(self, config: BacformerConfig, benchmark_esm: bool = False):
+         super().__init__(config)
+         self.config = config
+         self.benchmark_esm = benchmark_esm
+         print("Benchmark ESM:", self.benchmark_esm)
+         self.return_attn_weights = config.return_attn_weights
+
+         self.bacformer = BacformerModel(config, add_pooling_layer=False)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.dense = nn.Sequential(
+             nn.Linear(config.hidden_size, config.hidden_size),
+             nn.GELU(),
+             nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
+             nn.Dropout(0.2),
+         )
+         self.ppi_head = BacformerProteinProteinInteractionHead(
+             in_features=config.hidden_size, prot_emb_idx=config.prot_emb_token_id
+         )
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         protein_embeddings: torch.Tensor,
+         special_tokens_mask: torch.Tensor,
+         labels: torch.Tensor = None,
+         token_type_ids: torch.Tensor = None,
+         attention_mask: torch.Tensor = None,
+         return_attn_weights: bool = None,
+         return_dict: Union[bool, None] = None,
+     ) -> Union[OrderedDict, None]:  # TODO: change it from token classifier output
+         """Forward method for the model."""
+         return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+         if self.benchmark_esm:
+             last_hidden_state = protein_embeddings.squeeze(0)[1:-2, :]
+         else:
+             outputs = self.bacformer(
+                 protein_embeddings=protein_embeddings,
+                 special_tokens_mask=special_tokens_mask,
+                 token_type_ids=token_type_ids,
+                 attention_mask=attention_mask,
+                 return_attn_weights=False,
+                 return_dict=True,
+             )
+             last_hidden_state = outputs.last_hidden_state.squeeze(0)[1:-2, :]
+
+         assert labels.shape[0] == 1, "Batch size should be 1 for protein-protein interaction task"
+
+         last_hidden_state = self.dense(self.dropout(last_hidden_state))
+         last_hidden_state = torch.cat([last_hidden_state[labels[:, 0]], last_hidden_state[labels[:, 1]]], dim=0).mean(
+             dim=0
+         )
+         logits = self.ppi_head(last_hidden_state)
+
+         loss = binary_cross_entropy_with_logits(logits, labels[:, 2].type_as(logits).squeeze(0))
+
+         if not return_dict:
+             return (
+                 loss,
+                 logits,
+             )
+
+         return BacformerModelOutput(
+             loss=loss,
+             logits=logits,
+             last_hidden_state=outputs.last_hidden_state,
+             attentions=outputs.attentions,
+         )
+
+
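# --- Editor's sketch (not part of modeling_bacformer.py) ---
# How the PPI forward above consumes its labels, assuming labels is a (1, 3)
# tensor of [protein_idx_a, protein_idx_b, interacts]. The pair embedding is
# the mean of the two protein hidden states, scored by a single linear unit
# (as in BacformerProteinProteinInteractionHead). All tensors are random
# stand-ins; sizes and indices are assumptions.
import torch
from torch import nn
from torch.nn.functional import binary_cross_entropy_with_logits

n_proteins, hidden = 10, 480
last_hidden_state = torch.randn(n_proteins, hidden)    # per-protein states for one genome
labels = torch.tensor([[2, 7, 1]])                     # proteins 2 and 7, interaction label 1

pair = torch.cat(
    [last_hidden_state[labels[:, 0]], last_hidden_state[labels[:, 1]]], dim=0
).mean(dim=0)                                          # (hidden,)
ppi_linear = nn.Linear(hidden, 1)                      # stand-in for the PPI head
logit = ppi_linear(pair).squeeze(-1)                   # scalar logit
loss = binary_cross_entropy_with_logits(logit, labels[:, 2].type_as(logit).squeeze(0))
print(float(logit), float(loss))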
+ # Copied from transformers.models.bert.modeling_bert.BertPooler
+ class BacformerPooler(nn.Module):
+     """Pooler for Bacformer model."""
+
+     def __init__(self, config: BacformerConfig):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.activation = nn.Tanh()
+
+     def forward(self, hidden_states: torch.Tensor, padding_mask: torch.Tensor = None) -> torch.Tensor:
+         """Forward method for the pooler."""
+         # We "pool" the model by taking the mean of non-padding tokens
+         padding_mask = padding_mask.to(hidden_states.device) if padding_mask is not None else None
+         if padding_mask is not None:
+             mean_hidden_states = torch.einsum("ijk,ij->ik", hidden_states, padding_mask) / padding_mask.sum(
+                 1
+             ).unsqueeze(1)
+         else:
+             mean_hidden_states = hidden_states.mean(dim=1)
+         pooled_output = self.dense(mean_hidden_states)
+         pooled_output = self.activation(pooled_output)
+         return pooled_output
+
+
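# --- Editor's sketch (not part of modeling_bacformer.py) ---
# Sanity check of the pooling einsum above: the masked mean computed with
# torch.einsum matches averaging each sequence's non-padded rows directly.
# The sizes below are arbitrary.
import torch

torch.manual_seed(0)
hidden_states = torch.randn(2, 4, 8)                   # (batch, seq, hidden)
padding_mask = torch.tensor([[1., 1., 0., 0.],
                             [1., 1., 1., 1.]])
mean_a = torch.einsum("ijk,ij->ik", hidden_states, padding_mask) / padding_mask.sum(1).unsqueeze(1)
mean_b = torch.stack([
    hidden_states[i][padding_mask[i].bool()].mean(dim=0) for i in range(2)
])
print(torch.allclose(mean_a, mean_b, atol=1e-6))       # True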
+ class BacformerGMHead(nn.Module):
+     """Bacformer Head for genomic modeling."""
+
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+         # add 1 to config.protein_clusters_vocab_size to account for the end token
+         self.decoder = nn.Linear(config.hidden_size, config.protein_clusters_vocab_size + 1, bias=False)
+         self.bias = nn.Parameter(torch.zeros(config.protein_clusters_vocab_size + 1))
+
+     def forward(self, features, **kwargs):
+         """Forward method for the head."""
+         x = self.dense(features)
+         x = gelu(x)
+         x = self.layer_norm(x)
+
+         # project back to the number of labels, with bias
+         x = self.decoder(x) + self.bias
+         return x
+
+
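# --- Editor's sketch (not part of modeling_bacformer.py) ---
# The GM head above projects hidden states to protein_clusters_vocab_size + 1
# logits, where the extra slot is the end token. A toy check of the output
# shape; the sizes here are assumptions, the real values come from BacformerConfig.
import torch
from torch import nn

hidden, protein_clusters_vocab_size = 480, 1000
decoder = nn.Linear(hidden, protein_clusters_vocab_size + 1, bias=False)
bias = torch.zeros(protein_clusters_vocab_size + 1)
features = torch.randn(7, hidden)                      # e.g. 7 masked positions
logits = decoder(features) + bias
print(logits.shape)                                    # torch.Size([7, 1001])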
+ class BacformerGenomeClassificationHead(nn.Module):
+     """Head for genome-level classification tasks."""
+
+     def __init__(self, config: BacformerConfig):
+         super().__init__()
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+     def forward(self, features: torch.Tensor, padding_mask: torch.Tensor, **kwargs):
+         """Forward method for the head."""
+         if padding_mask is not None:
+             x = torch.einsum("ijk,ij->ik", features, padding_mask) / padding_mask.sum(1).unsqueeze(1)
+         else:
+             x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+         x = self.dropout(x)
+         x = self.out_proj(x)
+         return x
+
+
+ class BacformerProteinProteinInteractionHead(nn.Module):
+     """Head for protein-protein interaction task at a genome level."""
+
+     def __init__(self, in_features: int, prot_emb_idx: int = 4, bias: bool = True):
+         super().__init__()
+         self.in_features = in_features
+         self.prot_emb_idx = prot_emb_idx
+         self.dropout = nn.Dropout(0.2)
+         self.linear = nn.Linear(in_features, 1, bias=bias)
+
+     def forward(
+         self, hidden_states: torch.Tensor
+     ) -> torch.Tensor:  # special_tokens_mask: torch.Tensor, attentions: torch.Tensor):
+         """Forward method for the head."""
+         return self.linear(self.dropout(hidden_states)).squeeze(-1)
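# --- Editor's sketch (not part of modeling_bacformer.py) ---
# A hedged loading example: if the repository's config.json registers these
# classes via auto_map (the usual setup when configuration_bacformer.py and
# modeling_bacformer.py are uploaded alongside config.json), the custom
# architecture can be loaded with trust_remote_code=True. The repo id below is
# a placeholder, not a confirmed model id; substitute the actual Hub repository.
from transformers import AutoConfig, AutoModel

repo_id = "<username>/<bacformer-repo>"                # placeholder repo id
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
print(type(model).__name__)                            # e.g. a Bacformer model class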