macwiatrak committed 8c51408 (verified) · 1 Parent(s): 099cfb6

Upload BacformerForCausalGM

config.json CHANGED
@@ -1,10 +1,13 @@
1
  {
2
- "_name_or_path": "/rds/user/mw896/rds-flotolab-9X9gY1OFt4M/projects/bacformer/output-data/complete-genomes/finetuning/runs-causal/30epochs/checkpoint-18750",
3
  "alpha_contrastive_loss": 0.5,
4
  "architectures": [
5
  "BacformerForCausalGM"
6
  ],
7
  "attention_probs_dropout_prob": 0.1,
8
  "batch_size": 1,
9
  "ckpt_path": null,
10
  "dataloader_num_workers": 10,
@@ -15,212 +18,14 @@
15
  "hidden_dropout_prob": 0.1,
16
  "hidden_size": 480,
17
  "id2label": {
18
- "0": "LABEL_0",
19
- "1": "LABEL_1",
20
- "2": "LABEL_2",
21
- "3": "LABEL_3",
22
- "4": "LABEL_4",
23
- "5": "LABEL_5",
24
- "6": "LABEL_6",
25
- "7": "LABEL_7",
26
- "8": "LABEL_8",
27
- "9": "LABEL_9",
28
- "10": "LABEL_10",
29
- "11": "LABEL_11",
30
- "12": "LABEL_12",
31
- "13": "LABEL_13",
32
- "14": "LABEL_14",
33
- "15": "LABEL_15",
34
- "16": "LABEL_16",
35
- "17": "LABEL_17",
36
- "18": "LABEL_18",
37
- "19": "LABEL_19",
38
- "20": "LABEL_20",
39
- "21": "LABEL_21",
40
- "22": "LABEL_22",
41
- "23": "LABEL_23",
42
- "24": "LABEL_24",
43
- "25": "LABEL_25",
44
- "26": "LABEL_26",
45
- "27": "LABEL_27",
46
- "28": "LABEL_28",
47
- "29": "LABEL_29",
48
- "30": "LABEL_30",
49
- "31": "LABEL_31",
50
- "32": "LABEL_32",
51
- "33": "LABEL_33",
52
- "34": "LABEL_34",
53
- "35": "LABEL_35",
54
- "36": "LABEL_36",
55
- "37": "LABEL_37",
56
- "38": "LABEL_38",
57
- "39": "LABEL_39",
58
- "40": "LABEL_40",
59
- "41": "LABEL_41",
60
- "42": "LABEL_42",
61
- "43": "LABEL_43",
62
- "44": "LABEL_44",
63
- "45": "LABEL_45",
64
- "46": "LABEL_46",
65
- "47": "LABEL_47",
66
- "48": "LABEL_48",
67
- "49": "LABEL_49",
68
- "50": "LABEL_50",
69
- "51": "LABEL_51",
70
- "52": "LABEL_52",
71
- "53": "LABEL_53",
72
- "54": "LABEL_54",
73
- "55": "LABEL_55",
74
- "56": "LABEL_56",
75
- "57": "LABEL_57",
76
- "58": "LABEL_58",
77
- "59": "LABEL_59",
78
- "60": "LABEL_60",
79
- "61": "LABEL_61",
80
- "62": "LABEL_62",
81
- "63": "LABEL_63",
82
- "64": "LABEL_64",
83
- "65": "LABEL_65",
84
- "66": "LABEL_66",
85
- "67": "LABEL_67",
86
- "68": "LABEL_68",
87
- "69": "LABEL_69",
88
- "70": "LABEL_70",
89
- "71": "LABEL_71",
90
- "72": "LABEL_72",
91
- "73": "LABEL_73",
92
- "74": "LABEL_74",
93
- "75": "LABEL_75",
94
- "76": "LABEL_76",
95
- "77": "LABEL_77",
96
- "78": "LABEL_78",
97
- "79": "LABEL_79",
98
- "80": "LABEL_80",
99
- "81": "LABEL_81",
100
- "82": "LABEL_82",
101
- "83": "LABEL_83",
102
- "84": "LABEL_84",
103
- "85": "LABEL_85",
104
- "86": "LABEL_86",
105
- "87": "LABEL_87",
106
- "88": "LABEL_88",
107
- "89": "LABEL_89",
108
- "90": "LABEL_90",
109
- "91": "LABEL_91",
110
- "92": "LABEL_92",
111
- "93": "LABEL_93",
112
- "94": "LABEL_94",
113
- "95": "LABEL_95",
114
- "96": "LABEL_96",
115
- "97": "LABEL_97",
116
- "98": "LABEL_98",
117
- "99": "LABEL_99"
118
  },
119
  "initializer_range": 0.02,
120
  "input_dir": "/rds/user/mw896/rds-flotolab-9X9gY1OFt4M/projects/bacformer/input-data/eval-genomes/",
121
  "intermediate_size": 1280,
122
  "is_causal_gm": true,
123
  "label2id": {
124
- "LABEL_0": 0,
125
- "LABEL_1": 1,
126
- "LABEL_10": 10,
127
- "LABEL_11": 11,
128
- "LABEL_12": 12,
129
- "LABEL_13": 13,
130
- "LABEL_14": 14,
131
- "LABEL_15": 15,
132
- "LABEL_16": 16,
133
- "LABEL_17": 17,
134
- "LABEL_18": 18,
135
- "LABEL_19": 19,
136
- "LABEL_2": 2,
137
- "LABEL_20": 20,
138
- "LABEL_21": 21,
139
- "LABEL_22": 22,
140
- "LABEL_23": 23,
141
- "LABEL_24": 24,
142
- "LABEL_25": 25,
143
- "LABEL_26": 26,
144
- "LABEL_27": 27,
145
- "LABEL_28": 28,
146
- "LABEL_29": 29,
147
- "LABEL_3": 3,
148
- "LABEL_30": 30,
149
- "LABEL_31": 31,
150
- "LABEL_32": 32,
151
- "LABEL_33": 33,
152
- "LABEL_34": 34,
153
- "LABEL_35": 35,
154
- "LABEL_36": 36,
155
- "LABEL_37": 37,
156
- "LABEL_38": 38,
157
- "LABEL_39": 39,
158
- "LABEL_4": 4,
159
- "LABEL_40": 40,
160
- "LABEL_41": 41,
161
- "LABEL_42": 42,
162
- "LABEL_43": 43,
163
- "LABEL_44": 44,
164
- "LABEL_45": 45,
165
- "LABEL_46": 46,
166
- "LABEL_47": 47,
167
- "LABEL_48": 48,
168
- "LABEL_49": 49,
169
- "LABEL_5": 5,
170
- "LABEL_50": 50,
171
- "LABEL_51": 51,
172
- "LABEL_52": 52,
173
- "LABEL_53": 53,
174
- "LABEL_54": 54,
175
- "LABEL_55": 55,
176
- "LABEL_56": 56,
177
- "LABEL_57": 57,
178
- "LABEL_58": 58,
179
- "LABEL_59": 59,
180
- "LABEL_6": 6,
181
- "LABEL_60": 60,
182
- "LABEL_61": 61,
183
- "LABEL_62": 62,
184
- "LABEL_63": 63,
185
- "LABEL_64": 64,
186
- "LABEL_65": 65,
187
- "LABEL_66": 66,
188
- "LABEL_67": 67,
189
- "LABEL_68": 68,
190
- "LABEL_69": 69,
191
- "LABEL_7": 7,
192
- "LABEL_70": 70,
193
- "LABEL_71": 71,
194
- "LABEL_72": 72,
195
- "LABEL_73": 73,
196
- "LABEL_74": 74,
197
- "LABEL_75": 75,
198
- "LABEL_76": 76,
199
- "LABEL_77": 77,
200
- "LABEL_78": 78,
201
- "LABEL_79": 79,
202
- "LABEL_8": 8,
203
- "LABEL_80": 80,
204
- "LABEL_81": 81,
205
- "LABEL_82": 82,
206
- "LABEL_83": 83,
207
- "LABEL_84": 84,
208
- "LABEL_85": 85,
209
- "LABEL_86": 86,
210
- "LABEL_87": 87,
211
- "LABEL_88": 88,
212
- "LABEL_89": 89,
213
- "LABEL_9": 9,
214
- "LABEL_90": 90,
215
- "LABEL_91": 91,
216
- "LABEL_92": 92,
217
- "LABEL_93": 93,
218
- "LABEL_94": 94,
219
- "LABEL_95": 95,
220
- "LABEL_96": 96,
221
- "LABEL_97": 97,
222
- "LABEL_98": 98,
223
- "LABEL_99": 99
224
  },
225
  "layer_norm_eps": 1e-12,
226
  "logging_steps": 500,
@@ -260,9 +65,9 @@
260
  },
261
  "test": false,
262
  "test_after_train": false,
263
- "torch_dtype": "float32",
264
  "train_subset_prop": 1.0,
265
- "transformers_version": "4.38.2",
266
  "warmup_proportion": 0.1,
267
  "weight_decay": 0.01
268
  }
 
1
  {
 
2
  "alpha_contrastive_loss": 0.5,
3
  "architectures": [
4
  "BacformerForCausalGM"
5
  ],
6
  "attention_probs_dropout_prob": 0.1,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_bacformer.BacformerConfig",
9
+ "AutoModelForCausalLM": "modeling_bacformer.BacformerForCausalGM"
10
+ },
11
  "batch_size": 1,
12
  "ckpt_path": null,
13
  "dataloader_num_workers": 10,
 
18
  "hidden_dropout_prob": 0.1,
19
  "hidden_size": 480,
20
  "id2label": {
21
+ "0": "LABEL_0"
22
  },
23
  "initializer_range": 0.02,
24
  "input_dir": "/rds/user/mw896/rds-flotolab-9X9gY1OFt4M/projects/bacformer/input-data/eval-genomes/",
25
  "intermediate_size": 1280,
26
  "is_causal_gm": true,
27
  "label2id": {
28
+ "LABEL_0": 0
29
  },
30
  "layer_norm_eps": 1e-12,
31
  "logging_steps": 500,
 
65
  },
66
  "test": false,
67
  "test_after_train": false,
68
+ "torch_dtype": "bfloat16",
69
  "train_subset_prop": 1.0,
70
+ "transformers_version": "4.50.3",
71
  "warmup_proportion": 0.1,
72
  "weight_decay": 0.01
73
  }
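The newly added auto_map entries let this checkpoint be loaded through the Auto classes once remote code is trusted. A minimal loading sketch, assuming a placeholder repository id (the actual Hub id is not part of this diff) and the bfloat16 dtype recorded above:

import torch
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "<user>/<bacformer-causal-repo>"  # placeholder; substitute the actual Hub repository id
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,      # resolves configuration_bacformer.BacformerConfig / modeling_bacformer.BacformerForCausalGM
    torch_dtype=torch.bfloat16,  # matches the updated "torch_dtype" in config.json
)
print(type(model).__name__)      # BacformerForCausalGM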
configuration_bacformer.py ADDED
@@ -0,0 +1,72 @@
1
+ from typing import Literal
2
+
3
+ from transformers import PretrainedConfig
4
+
5
+ SPECIAL_TOKENS_DICT = {
6
+ "PAD": 0,
7
+ "MASK": 1,
8
+ "CLS": 2,
9
+ "SEP": 3,
10
+ "PROT_EMB": 4,
11
+ "END": 5,
12
+ }
13
+
14
+
15
+ class BacformerConfig(PretrainedConfig):
16
+ """Configuration class to store the configuration of a `BacformerModel`."""
17
+
18
+ model_type = "bacformer"
19
+
20
+ def __init__(
21
+ self,
22
+ num_hidden_layers: int = 6,
23
+ num_attention_heads: int = 8,
24
+ hidden_size: int = 480, # default esm2_t12_35M_UR50D embedding dim
25
+ intermediate_size: int = 1280,
26
+ hidden_dropout_prob: float = 0.1,
27
+ attention_probs_dropout_prob: float = 0.1,
28
+ max_position_embeddings: int = 6000,
29
+ max_token_type_embeddings: int = 1000,
30
+ layer_norm_eps: float = 1e-12,
31
+ initializer_range: float = 0.02,
32
+ pad_token_id: int = SPECIAL_TOKENS_DICT["PAD"],
33
+ mask_token_id: int = SPECIAL_TOKENS_DICT["MASK"],
34
+ prot_emb_token_id: int = SPECIAL_TOKENS_DICT["PROT_EMB"],
35
+ end_token_id: int = SPECIAL_TOKENS_DICT["END"],
36
+ num_special_tokens: int = len(SPECIAL_TOKENS_DICT),
37
+ protein_clusters_vocab_size: int = 50001, # equal to the nr of protein clusters + 1
38
+ num_labels: int = 1, # for downstream tasks
39
+ is_causal_gm: bool = False,
40
+ return_dict: bool = False,
41
+ return_attn_weights: bool = False,
42
+ alpha_contrastive_loss: float = 0.5,
43
+ # only to use in the BacformerForGenomeClassification
44
+ problem_type: Literal[
45
+ "regression", "binary_classification", "single_label_classification", "multi_label_classification"
46
+ ] = "single_label_classification",
47
+ **kwargs,
48
+ ):
49
+ super().__init__(**kwargs)
50
+
51
+ self.num_hidden_layers = num_hidden_layers
52
+ self.num_attention_heads = num_attention_heads
53
+ self.hidden_size = hidden_size
54
+ self.intermediate_size = intermediate_size
55
+ self.hidden_dropout_prob = hidden_dropout_prob
56
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
57
+ self.max_position_embeddings = max_position_embeddings
58
+ self.max_token_type_embeddings = max_token_type_embeddings
59
+ self.layer_norm_eps = layer_norm_eps
60
+ self.initializer_range = initializer_range
61
+ self.pad_token_id = pad_token_id
62
+ self.mask_token_id = mask_token_id
63
+ self.prot_emb_token_id = prot_emb_token_id
64
+ self.end_token_id = end_token_id
65
+ self.num_special_tokens = num_special_tokens
66
+ self.protein_clusters_vocab_size = protein_clusters_vocab_size
67
+ self.num_labels = num_labels
68
+ self.is_causal_gm = is_causal_gm
69
+ self.return_dict = return_dict
70
+ self.return_attn_weights = return_attn_weights
71
+ self.problem_type = problem_type
72
+ self.alpha_contrastive_loss = alpha_contrastive_loss
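As a quick sanity check of the defaults defined above, the config can be instantiated on its own (illustrative only; assumes configuration_bacformer.py has been downloaded and is importable locally):

from configuration_bacformer import SPECIAL_TOKENS_DICT, BacformerConfig

config = BacformerConfig(is_causal_gm=True, return_dict=True)
assert config.hidden_size == 480 and config.num_attention_heads == 8
assert config.pad_token_id == SPECIAL_TOKENS_DICT["PAD"]
print(config.model_type, config.protein_clusters_vocab_size)  # bacformer 50001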
model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:2b8915be963bf675d706f653734a802570116ddf5c093645baa3f5d88846fd57
- size 203428892
+ oid sha256:79fe7f24e420eeb8ddca0af7bcd868dff687d9b3902a318467d6e143ad3fefba
+ size 101724522
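The smaller file is consistent with the torch_dtype change in config.json: the same weights stored in bfloat16 (2 bytes per parameter) rather than float32 (4 bytes) take roughly half the space, with the small remainder plausibly due to the safetensors header and metadata. A rough back-of-the-envelope check using the two sizes above:

old_size, new_size = 203_428_892, 101_724_522
print(old_size / 4 / 1e6)        # ~50.9M parameters if the old file was essentially all float32
print(new_size - old_size // 2)  # ~10 kB difference, plausibly header/metadata overhead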
modeling_bacformer.py ADDED
@@ -0,0 +1,1340 @@
1
+ import math
2
+ from collections import OrderedDict
3
+ from dataclasses import dataclass
4
+ from typing import Literal, Optional, Union
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn.functional import (
9
+ binary_cross_entropy_with_logits,
10
+ cross_entropy,
11
+ gelu,
12
+ mse_loss,
13
+ scaled_dot_product_attention,
14
+ softmax,
15
+ )
16
+ from transformers import PreTrainedModel
17
+ from transformers.utils import ModelOutput
18
+
19
+ from .configuration_bacformer import SPECIAL_TOKENS_DICT, BacformerConfig
20
+ from .utils_bacformer import compute_contrastive_loss, create_4d_from_2d_attn_mask, top_k_filtering, top_p_filtering
21
+
22
+
23
+ @dataclass
24
+ class BacformerModelOutput(ModelOutput):
25
+ """Base class for outputs of the Bacformer model."""
26
+
27
+ loss: torch.FloatTensor | None = None
28
+ logits: torch.FloatTensor = None
29
+ last_hidden_state: torch.FloatTensor | None = None
30
+ attentions: Union[torch.FloatTensor, None] = None
31
+ pooler_output: torch.FloatTensor | None = None
32
+
33
+
34
+ # Taken from facebookresearch/llama/model.py
35
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
36
+ """Reshape the rotary embeddings for broadcasting."""
37
+ ndim = x.ndim
38
+ assert 0 <= 1 < ndim
39
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
40
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
41
+ return freqs_cis.view(*shape)
42
+
43
+
44
+ # Taken from facebookresearch/llama/model.py
45
+ def apply_rotary_emb(
46
+ xq: torch.Tensor,
47
+ xk: torch.Tensor,
48
+ freqs_cos: torch.Tensor,
49
+ freqs_sin: torch.Tensor,
50
+ ) -> tuple[torch.Tensor, torch.Tensor]:
51
+ """Apply rotary embeddings to the query and key tensors."""
52
+ # reshape xq and xk to match the complex representation
53
+ xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1)
54
+ xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1)
55
+
56
+ # reshape freqs_cos and freqs_sin for broadcasting
57
+ freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
58
+ freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
59
+
60
+ # apply rotation using real numbers
61
+ xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
62
+ xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
63
+ xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
64
+ xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
65
+
66
+ # flatten last two dimensions
67
+ xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
68
+ xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
69
+
70
+ return xq_out.type_as(xq), xk_out.type_as(xk)
71
+
72
+
73
+ # Taken from facebookresearch/llama/model.py
74
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
75
+ """Precompute the freqs cis for rotary embeddings."""
76
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
77
+ t = torch.arange(end, device=freqs.device) # type: ignore
78
+ freqs = torch.outer(t, freqs).float() # type: ignore
79
+
80
+ freqs_cos = torch.cos(freqs) # real part
81
+ freqs_sin = torch.sin(freqs) # imaginary part
82
+ return freqs_cos, freqs_sin
83
+
84
+
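# Illustration (not part of the upload): a quick shape/norm check of the rotary helpers above,
# run in this module's namespace. Sizes mirror the released config (480 hidden / 8 heads = 60 per
# head); the batch size and sequence length are arbitrary.
import torch

bs, seq_len, n_heads, head_dim = 2, 16, 8, 60
freqs_cos, freqs_sin = precompute_freqs_cis(head_dim, seq_len)   # shapes: (seq_len, head_dim // 2)
xq = torch.randn(bs, seq_len, n_heads, head_dim)
xk = torch.randn(bs, seq_len, n_heads, head_dim)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape
# each (real, imaginary) pair is rotated, so the overall norm is preserved up to float error
assert torch.allclose(xq_rot.norm(), xq.norm(), atol=1e-3)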
85
+ def scaled_dot_product_attention_w_attn_weights(
86
+ query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
87
+ ) -> tuple[torch.Tensor, torch.Tensor]:
88
+ """PyTorch Native implementation, modified to return attention weights."""
89
+ L, S = query.size(-2), key.size(-2)
90
+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
91
+ attn_bias = torch.zeros(L, S, dtype=query.dtype).to(query.device)
92
+ if is_causal:
93
+ assert attn_mask is None
94
+ temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
95
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
96
+ attn_bias.to(query.dtype)
97
+
98
+ if attn_mask is not None:
99
+ if attn_mask.dtype == torch.bool:
100
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
101
+ else:
102
+ attn_bias += attn_mask
103
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
104
+ attn_weight += attn_bias
105
+ attn_weight = torch.softmax(attn_weight, dim=-1)
106
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
107
+ attn_output = attn_weight @ value
108
+ return attn_output, attn_weight
109
+
110
+
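# Illustration (not part of the upload): with dropout disabled, the re-implementation above should
# agree with torch's native scaled_dot_product_attention while additionally exposing the weights.
import torch
from torch.nn.functional import scaled_dot_product_attention

q, k, v = (torch.randn(1, 8, 16, 60) for _ in range(3))
out, weights = scaled_dot_product_attention_w_attn_weights(q, k, v, dropout_p=0.0, is_causal=True)
ref = scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
assert torch.allclose(out, ref, atol=1e-4)
assert weights.shape == (1, 8, 16, 16)  # per-head attention maps, which the native kernel does not return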
111
+ class RotarySelfAttention(nn.Module):
112
+ """Rotary self-attention module."""
113
+
114
+ def __init__(
115
+ self,
116
+ embed_dim: int,
117
+ num_heads: int,
118
+ dropout: float = 0.1,
119
+ ):
120
+ super().__init__()
121
+ self.embed_dim = embed_dim
122
+ self.num_heads = num_heads
123
+ self.dim_head = embed_dim // num_heads
124
+ self.dropout_rate = dropout
125
+
126
+ self.q = nn.Linear(embed_dim, embed_dim, bias=False)
127
+ self.k = nn.Linear(embed_dim, embed_dim, bias=False)
128
+ self.v = nn.Linear(embed_dim, embed_dim, bias=False)
129
+ self.att_proj_linear = nn.Linear(embed_dim, embed_dim)
130
+
131
+ def forward(
132
+ self,
133
+ x: torch.Tensor,
134
+ attn_mask: torch.Tensor,
135
+ freqs_cos: torch.Tensor,
136
+ freqs_sin: torch.Tensor,
137
+ is_causal: bool = False,
138
+ return_attn_weights: bool = False,
139
+ ):
140
+ """Forward pass for the rotary self-attention module."""
141
+ batch_size, seq_len, _ = x.shape
142
+ xq, xk, xv = self.q(x), self.k(x), self.v(x)
143
+ # Reshape for rotary embeddings
144
+ xq = xq.view(batch_size, seq_len, self.num_heads, self.dim_head)
145
+ xk = xk.view(batch_size, seq_len, self.num_heads, self.dim_head)
146
+ xv = xv.view(batch_size, seq_len, self.num_heads, self.dim_head)
147
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
148
+
149
+ # Reshape for attention calculation: (b_sz, n_head, s_len, d_head)
150
+ xq = xq.transpose(1, 2)
151
+ xk = xk.transpose(1, 2)
152
+ xv = xv.transpose(1, 2)
153
+
154
+ attn_weights = None
155
+ if return_attn_weights:
156
+ att, attn_weights = scaled_dot_product_attention_w_attn_weights(
157
+ query=xq,
158
+ key=xk,
159
+ value=xv,
160
+ attn_mask=attn_mask,
161
+ dropout_p=self.dropout_rate if self.training else 0.0,
162
+ is_causal=is_causal,
163
+ )
164
+ else:
165
+ att = scaled_dot_product_attention(
166
+ query=xq,
167
+ key=xk,
168
+ value=xv,
169
+ attn_mask=attn_mask,
170
+ dropout_p=self.dropout_rate if self.training else 0.0,
171
+ is_causal=is_causal,
172
+ )
173
+ # Shape (b_sz, s_len, n_head, d_head)
174
+ out = att.transpose(1, 2).contiguous()
175
+ out = out.view(batch_size, seq_len, self.num_heads * self.dim_head)
176
+
177
+ return self.att_proj_linear(out), attn_weights
178
+
179
+
180
+ class BacformerTransformerLayer(nn.Module):
181
+ """Own implementation of transformer layer which uses pytorch native MHA but returns attention weights"""
182
+
183
+ def __init__(
184
+ self,
185
+ hidden_size: int,
186
+ intermediate_size: int,
187
+ num_attention_heads: int,
188
+ dropout: float = 0.1,
189
+ activation: Literal["gelu", "relu"] = "gelu",
190
+ ):
191
+ super().__init__()
192
+ self.self_mha = RotarySelfAttention(
193
+ embed_dim=hidden_size,
194
+ num_heads=num_attention_heads,
195
+ dropout=dropout,
196
+ )
197
+
198
+ self.fc1 = nn.Linear(hidden_size, intermediate_size)
199
+ self.fc2 = nn.Linear(intermediate_size, hidden_size)
200
+ self.activation = nn.GELU() if activation == "gelu" else nn.ReLU()
201
+ self.norm1 = nn.LayerNorm(hidden_size)
202
+ self.norm2 = nn.LayerNorm(hidden_size)
203
+ self.dropout1 = nn.Dropout(dropout)
204
+ self.dropout2 = nn.Dropout(dropout)
205
+ self.dropout3 = nn.Dropout(dropout)
206
+
207
+ def forward(
208
+ self,
209
+ hidden_state: torch.Tensor,
210
+ attention_mask: torch.Tensor = None,
211
+ freqs_cos: torch.Tensor = None,
212
+ freqs_sin: torch.Tensor = None,
213
+ return_attn_weights: bool = False,
214
+ is_causal: bool = False,
215
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
216
+ """Forward pass"""
217
+ attn_outputs, attn_weights = self.self_mha(
218
+ hidden_state,
219
+ attn_mask=attention_mask,
220
+ freqs_cos=freqs_cos,
221
+ freqs_sin=freqs_sin,
222
+ return_attn_weights=return_attn_weights,
223
+ is_causal=is_causal,
224
+ )
225
+ x = self.norm1(hidden_state + self.dropout1(attn_outputs))
226
+ ff_output = self.fc2(self.dropout2(self.activation(self.fc1(x))))
227
+ x = self.norm2(x + self.dropout3(ff_output))
228
+ return x, attn_weights
229
+
230
+
231
+ class BacformerTransformerEncoder(nn.Module):
232
+ """Own implementation of Transformer which return attention weights"""
233
+
234
+ def __init__(
235
+ self,
236
+ num_hidden_layers: int,
237
+ hidden_size: int,
238
+ intermediate_size: int,
239
+ num_attention_heads: int,
240
+ dropout: float = 0.1,
241
+ activation: Literal["gelu", "relu"] = "gelu",
242
+ ):
243
+ super().__init__()
244
+
245
+ self.layers = nn.ModuleList(
246
+ [
247
+ BacformerTransformerLayer(
248
+ hidden_size=hidden_size,
249
+ intermediate_size=intermediate_size,
250
+ num_attention_heads=num_attention_heads,
251
+ dropout=dropout,
252
+ activation=activation,
253
+ )
254
+ for _ in range(num_hidden_layers)
255
+ ]
256
+ )
257
+ self.gradient_checkpointing = False
258
+
259
+ def forward(
260
+ self,
261
+ hidden_state: torch.Tensor,
262
+ attention_mask: torch.Tensor = None,
263
+ freqs_cos: torch.Tensor = None,
264
+ freqs_sin: torch.Tensor = None,
265
+ return_attn_weights: bool = False,
266
+ is_causal: bool = False,
267
+ ) -> tuple[torch.Tensor, list[torch.Tensor | None]]:
268
+ """Forward pass"""
269
+ attn_weights_arr = []
270
+ for layer in self.layers:
271
+ if self.gradient_checkpointing and self.training:
272
+ hidden_state, attn_weights = self._gradient_checkpointing_func(
273
+ layer.__call__,
274
+ hidden_state,
275
+ attention_mask,
276
+ freqs_cos,
277
+ freqs_sin,
278
+ return_attn_weights,
279
+ is_causal,
280
+ )
281
+ else:
282
+ hidden_state, attn_weights = layer(
283
+ hidden_state=hidden_state,
284
+ attention_mask=attention_mask,
285
+ freqs_cos=freqs_cos,
286
+ freqs_sin=freqs_sin,
287
+ return_attn_weights=return_attn_weights,
288
+ is_causal=is_causal,
289
+ )
290
+ # keep the attention weights from each layer
291
+ attn_weights_arr.append(attn_weights)
292
+ return hidden_state, attn_weights_arr
293
+
294
+
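# Illustration (not part of the upload): a forward pass through the stand-alone encoder above with
# toy sizes; freqs come from precompute_freqs_cis with dim = hidden_size // num_attention_heads.
import torch

enc = BacformerTransformerEncoder(num_hidden_layers=2, hidden_size=64, intermediate_size=128, num_attention_heads=4)
enc.eval()  # disable dropout so the toy run is deterministic
freqs_cos, freqs_sin = precompute_freqs_cis(64 // 4, 10)
x = torch.randn(1, 10, 64)
out, attn = enc(x, freqs_cos=freqs_cos, freqs_sin=freqs_sin, return_attn_weights=True, is_causal=True)
print(out.shape)      # torch.Size([1, 10, 64])
print(attn[0].shape)  # torch.Size([1, 4, 10, 10]) -- one attention map per layer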
295
+ class BacformerEmbeddings(nn.Module):
296
+ """Construct the protein embeddings from protein sequence, position embeddings and sequence type embeddings."""
297
+
298
+ def __init__(self, config):
299
+ super().__init__()
300
+ self.config = config
301
+ self.linear = nn.Linear(config.hidden_size, config.hidden_size)
302
+
303
+ self.token_type_embeddings = nn.Embedding(
304
+ num_embeddings=config.max_token_type_embeddings + 1,
305
+ embedding_dim=config.hidden_size,
306
+ padding_idx=config.max_token_type_embeddings,
307
+ )
308
+
309
+ self.special_tokens_embeddings = nn.Embedding(
310
+ num_embeddings=config.num_special_tokens,
311
+ embedding_dim=config.hidden_size,
312
+ )
313
+ self.prot_emb_token_id = config.prot_emb_token_id
314
+ self.pad_token_id = config.pad_token_id
315
+
316
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
317
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
318
+
319
+ def forward(
320
+ self,
321
+ protein_embeddings: torch.Tensor = None,
322
+ special_tokens_mask: torch.Tensor = None,
323
+ token_type_ids: torch.Tensor = None,
324
+ labels: torch.Tensor = None, # used for causal protein family modeling
325
+ property_ids: torch.Tensor = None, # used for conditional fine-tuning for desired property
326
+ ) -> torch.Tensor:
327
+ """Forward pass for protein embeddings."""
328
+ bs, seq_length, dim = protein_embeddings.shape
329
+
330
+ # pass the pooled ESM protein embeddings through a linear layer
331
+ protein_embeddings = self.linear(protein_embeddings)
332
+ protein_embeddings = torch.where(
333
+ special_tokens_mask.unsqueeze(-1).repeat(1, 1, dim) == self.prot_emb_token_id,
334
+ protein_embeddings,
335
+ self.special_tokens_embeddings(special_tokens_mask),
336
+ )
337
+
338
+ if token_type_ids is not None:
339
+ protein_embeddings += self.token_type_embeddings(token_type_ids)
340
+
341
+ protein_embeddings = self.LayerNorm(protein_embeddings)
342
+ protein_embeddings = self.dropout(protein_embeddings)
343
+ return protein_embeddings
344
+
345
+
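# Illustration (not part of the upload) of the torch.where selection used above: positions whose
# special-token id equals PROT_EMB keep the (projected) protein embedding, every other position is
# replaced by a learned special-token embedding (CLS, SEP, PAD, ...). Values below are arbitrary.
import torch

prot_emb_token_id, dim = 4, 8                          # 4 == SPECIAL_TOKENS_DICT["PROT_EMB"]
special_tokens_mask = torch.tensor([[2, 4, 4, 3]])     # CLS, PROT_EMB, PROT_EMB, SEP
protein_embeddings = torch.randn(1, 4, dim)
special_tokens_embeddings = torch.nn.Embedding(6, dim)
mixed = torch.where(
    special_tokens_mask.unsqueeze(-1).repeat(1, 1, dim) == prot_emb_token_id,
    protein_embeddings,
    special_tokens_embeddings(special_tokens_mask),
)
assert torch.equal(mixed[0, 1], protein_embeddings[0, 1])             # protein slot kept
assert torch.equal(mixed[0, 0], special_tokens_embeddings.weight[2])  # CLS slot replaced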
346
+ class BacformerProteinFamilyEmbeddings(nn.Module):
347
+ """Construct the protein embeddings from protein family tokens, special tokens and sequence type embeddings."""
348
+
349
+ def __init__(
350
+ self,
351
+ config,
352
+ protein_family_embeddings: torch.Tensor = None,
353
+ token_type_embeddings: torch.Tensor = None,
354
+ special_tokens_embeddings: torch.Tensor = None,
355
+ n_conditional_properties: int = None,
356
+ ):
357
+ super().__init__()
358
+ self.config = config
359
+
360
+ if protein_family_embeddings is not None:
361
+ self.protein_family_embeddings = nn.Embedding.from_pretrained(
362
+ protein_family_embeddings,
363
+ freeze=False,
364
+ padding_idx=config.pad_token_id,
365
+ )
366
+ else:
367
+ self.protein_family_embeddings = nn.Embedding(
368
+ num_embeddings=config.protein_clusters_vocab_size + 1,
369
+ embedding_dim=config.hidden_size,
370
+ padding_idx=config.pad_token_id,
371
+ )
372
+
373
+ if token_type_embeddings is not None:
374
+ self.token_type_embeddings = nn.Embedding.from_pretrained(
375
+ token_type_embeddings,
376
+ freeze=False,
377
+ padding_idx=config.max_token_type_embeddings,
378
+ )
379
+ else:
380
+ self.token_type_embeddings = nn.Embedding(
381
+ num_embeddings=config.max_token_type_embeddings + 1,
382
+ embedding_dim=config.hidden_size,
383
+ padding_idx=config.max_token_type_embeddings,
384
+ )
385
+
386
+ if special_tokens_embeddings is not None:
387
+ self.special_tokens_embeddings = nn.Embedding.from_pretrained(
388
+ special_tokens_embeddings,
389
+ freeze=False,
390
+ padding_idx=config.pad_token_id,
391
+ )
392
+ else:
393
+ self.special_tokens_embeddings = nn.Embedding(
394
+ num_embeddings=config.num_special_tokens,
395
+ embedding_dim=config.hidden_size,
396
+ padding_idx=config.pad_token_id,
397
+ )
398
+
399
+ # add layer for conditional properties
400
+ if n_conditional_properties is not None:
401
+ self.conditional_properties_layer = nn.Embedding(n_conditional_properties, config.hidden_size)
402
+
403
+ self.prot_emb_token_id = config.prot_emb_token_id
404
+ self.pad_token_id = config.pad_token_id
405
+
406
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
407
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
408
+
409
+ def forward(
410
+ self,
411
+ protein_embeddings: torch.Tensor = None,
412
+ special_tokens_mask: torch.Tensor = None,
413
+ token_type_ids: torch.Tensor = None,
414
+ labels: torch.Tensor = None, # used for causal protein family modeling
415
+ property_ids: torch.Tensor = None, # used for conditional fine-tuning for desired property
416
+ ) -> torch.Tensor:
417
+ """Forward pass for protein embeddings."""
418
+ # pass the pooled ESM protein embeddings through a linear layer
419
+ # replace -100 with pad_token_id
420
+ labels[labels == -100] = self.pad_token_id
421
+ protein_embeddings = self.protein_family_embeddings(labels)
422
+
423
+ bs, seq_length, dim = protein_embeddings.shape
424
+ protein_embeddings = torch.where(
425
+ special_tokens_mask.unsqueeze(-1).repeat(1, 1, dim) == self.prot_emb_token_id,
426
+ protein_embeddings,
427
+ self.special_tokens_embeddings(special_tokens_mask),
428
+ )
429
+
430
+ if token_type_ids is not None:
431
+ protein_embeddings += self.token_type_embeddings(token_type_ids)
432
+
433
+ if property_ids is not None:
434
+ # get the embeddings for the conditional properties
435
+ property_embedding = self.conditional_properties_layer(property_ids).unsqueeze(1)
436
+ # concatenate the protein embeddings with the conditional properties embeddings
437
+ # property embeddings are added to the beginning of the protein embeddings after the CLS token
438
+ protein_embeddings = torch.cat(
439
+ [
440
+ protein_embeddings[:, :1, :], # CLS token
441
+ property_embedding, # conditional properties embeddings
442
+ protein_embeddings[:, 1:, :],
443
+ ], # protein embeddings
444
+ dim=1,
445
+ )
446
+
447
+ protein_embeddings = self.LayerNorm(protein_embeddings)
448
+ protein_embeddings = self.dropout(protein_embeddings)
449
+ return protein_embeddings
450
+
451
+
452
+ class BacformerEncoder(nn.Module):
453
+ """Bacformer encoder model"""
454
+
455
+ def __init__(self, config):
456
+ super().__init__()
457
+ self.config = config
458
+
459
+ self.encoder = BacformerTransformerEncoder(
460
+ num_hidden_layers=config.num_hidden_layers,
461
+ hidden_size=config.hidden_size,
462
+ num_attention_heads=config.num_attention_heads,
463
+ intermediate_size=config.intermediate_size,
464
+ activation="gelu",
465
+ dropout=config.attention_probs_dropout_prob,
466
+ )
467
+
468
+ # Note that config.max_position_embeddings is multiplied by 1.5 because the token limit for the
+ # Bacformer family of models is 6000. Using this multiplier instead of hard-coding 6000 leaves
+ # headroom for longer token lengths during training or fine-tuning.
471
+ freqs_cos, freqs_sin = precompute_freqs_cis(
472
+ config.hidden_size // config.num_attention_heads, int(config.max_position_embeddings * 1.5)
473
+ )
474
+ self.register_buffer("freqs_cos", freqs_cos, persistent=False)
475
+ self.register_buffer("freqs_sin", freqs_sin, persistent=False)
476
+
477
+ def forward(
478
+ self,
479
+ hidden_states: torch.Tensor,
480
+ attention_mask: torch.Tensor = None,
481
+ return_attn_weights: Union[bool, None] = None,
482
+ is_causal: bool = False,
483
+ ) -> tuple[torch.Tensor, list[torch.Tensor | None]]:
484
+ """Pass the input through the encoder layers in turn.
485
+
486
+ Args:
487
+ hidden_states: hidden states from the BacformerEmbeddings layer
488
+ attention_mask: mask for the attention in the transformer
489
+ """
490
+ return_attn_weights = (
491
+ return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
492
+ )
493
+ bs, seq_len, _ = hidden_states.shape
494
+ last_hidden_state, attn_weights = self.encoder(
495
+ hidden_state=hidden_states,
496
+ attention_mask=attention_mask,
497
+ freqs_cos=self.freqs_cos[:seq_len, :],
498
+ freqs_sin=self.freqs_sin[:seq_len, :],
499
+ return_attn_weights=return_attn_weights,
500
+ is_causal=is_causal,
501
+ )
502
+ return last_hidden_state, attn_weights
503
+
504
+
505
+ class BacformerPreTrainedModel(PreTrainedModel):
506
+ """An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models."""
507
+
508
+ config_class = BacformerConfig
509
+ base_model_prefix = "bacformer"
510
+ supports_gradient_checkpointing = True
511
+ _no_split_modules = ["BacformerEmbeddings", "BacformerTransformerLayer"]
512
+
513
+ # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
514
+ def _init_weights(self, module):
515
+ """Initialize the weights"""
516
+ if isinstance(module, nn.Linear):
517
+ # Slightly different from the TF version which uses truncated_normal for initialization
518
+ # cf https://github.com/pytorch/pytorch/pull/5617
519
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
520
+ if module.bias is not None:
521
+ module.bias.data.zero_()
522
+ elif isinstance(module, nn.Embedding):
523
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
524
+ if module.padding_idx is not None:
525
+ module.weight.data[module.padding_idx].zero_()
526
+ elif isinstance(module, nn.LayerNorm):
527
+ module.bias.data.zero_()
528
+ module.weight.data.fill_(1.0)
529
+
530
+
531
+ class BacformerModel(BacformerPreTrainedModel):
532
+ """Bacformer model."""
533
+
534
+ def __init__(self, config: BacformerConfig, add_pooling_layer: bool = False):
535
+ super().__init__(config)
536
+ self.config = config
537
+
538
+ self.embeddings = BacformerEmbeddings(config)
539
+ self.encoder = BacformerEncoder(config)
540
+
541
+ self.pooler = BacformerPooler(config) if add_pooling_layer else None
542
+
543
+ # Initialize weights and apply final processing
544
+ self.post_init()
545
+
546
+ def forward(
547
+ self,
548
+ protein_embeddings: torch.Tensor = None,
549
+ special_tokens_mask: torch.Tensor = None,
550
+ token_type_ids: torch.Tensor = None,
551
+ attention_mask: torch.Tensor = None,
552
+ labels: torch.Tensor = None,
553
+ property_ids: torch.Tensor = None,
554
+ return_attn_weights: bool = False,
555
+ return_dict: Union[bool, None] = None,
556
+ is_causal: bool = False,
557
+ ) -> Optional[BacformerModelOutput]:
558
+ """Forward method for the model."""
559
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
560
+ # get embeddings
561
+ protein_embeddings = self.embeddings(
562
+ protein_embeddings=protein_embeddings,
563
+ labels=labels,
564
+ special_tokens_mask=special_tokens_mask,
565
+ token_type_ids=token_type_ids,
566
+ property_ids=property_ids,
567
+ )
568
+
569
+ # create a 4D attention mask from the 2D one if not doing causal GM
570
+ if attention_mask is not None and not is_causal:
571
+ attention_mask = create_4d_from_2d_attn_mask(
572
+ attn_mask=attention_mask, num_attn_heads=self.config.num_attention_heads
573
+ ).bool()
574
+
575
+ last_hidden_state, attentions = self.encoder(
576
+ hidden_states=protein_embeddings,
577
+ attention_mask=attention_mask,
578
+ return_attn_weights=return_attn_weights,
579
+ is_causal=is_causal,
580
+ )
581
+ pooler_output = (
582
+ self.pooler(hidden_states=last_hidden_state, padding_mask=attention_mask)
583
+ if self.pooler is not None
584
+ else None
585
+ )
586
+
587
+ if not return_dict:
588
+ return (last_hidden_state, pooler_output, attentions)
589
+
590
+ return BacformerModelOutput(
591
+ last_hidden_state=last_hidden_state,
592
+ pooler_output=pooler_output,
593
+ attentions=attentions,
594
+ )
595
+
596
+
597
+ class BacformerForCausalGM(BacformerPreTrainedModel):
598
+ """Bacformer model with genomic modeling head on top"""
599
+
600
+ _tied_weights_keys = ["gm_head.decoder.weight"]
601
+
602
+ def __init__(self, config: BacformerConfig):
603
+ super().__init__(config)
604
+ self.config = config
605
+
606
+ self.bacformer = BacformerModel(config, add_pooling_layer=False)
607
+ self.gm_head = BacformerGMHead(config)
608
+
609
+ # Initialize weights
610
+ self.init_weights()
611
+
612
+ def forward(
613
+ self,
614
+ protein_embeddings: torch.Tensor,
615
+ special_tokens_mask: torch.Tensor,
616
+ labels: torch.Tensor = None,
617
+ token_type_ids: torch.Tensor = None,
618
+ attention_mask: torch.Tensor = None,
619
+ return_attn_weights: bool = None,
620
+ return_dict: Union[bool, None] = None,
621
+ ) -> Optional[BacformerModelOutput]:
622
+ """Forward method for the model."""
623
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
624
+ return_attn_weights = (
625
+ return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
626
+ )
627
+
628
+ outputs = self.bacformer(
629
+ protein_embeddings=protein_embeddings,
630
+ special_tokens_mask=special_tokens_mask,
631
+ token_type_ids=token_type_ids,
632
+ attention_mask=None, # attention mechanism handles the causal mask
633
+ return_attn_weights=return_attn_weights,
634
+ return_dict=return_dict,
635
+ is_causal=True,
636
+ )
637
+ last_hidden_state = outputs[0]
638
+ prediction_scores = self.gm_head(last_hidden_state)
639
+
640
+ loss = None
641
+ if labels is not None:
642
+ labels = labels.to(prediction_scores.device)
643
+
644
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous().view(-1, prediction_scores.shape[-1])
645
+ labels = labels[:, 1:].contiguous().view(-1)
646
+ loss = cross_entropy(shifted_prediction_scores, labels)
647
+
648
+ if not return_dict:
649
+ return (
650
+ loss,
651
+ prediction_scores,
652
+ ) + outputs
653
+
654
+ return BacformerModelOutput(
655
+ loss=loss,
656
+ logits=prediction_scores,
657
+ last_hidden_state=outputs.last_hidden_state,
658
+ attentions=outputs.attentions,
659
+ )
660
+
661
+
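# Illustration (not part of the upload): the loss in BacformerForCausalGM.forward is standard
# next-token prediction over protein families -- logits at position t are scored against the label
# at position t + 1. The slicing above is equivalent to this small sketch with toy tensors.
import torch
from torch.nn.functional import cross_entropy

vocab_size, seq_len = 7, 5
prediction_scores = torch.randn(1, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (1, seq_len))
shifted_scores = prediction_scores[:, :-1, :].contiguous().view(-1, vocab_size)  # predictions for t = 0 .. seq_len - 2
shifted_labels = labels[:, 1:].contiguous().view(-1)                             # targets for t = 1 .. seq_len - 1
print(cross_entropy(shifted_scores, shifted_labels).item())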
662
+ class BacformerForMaskedGM(BacformerPreTrainedModel):
663
+ """Bacformer model with genomic modeling head on top"""
664
+
665
+ _tied_weights_keys = ["gm_head.decoder.weight"]
666
+
667
+ def __init__(self, config: BacformerConfig):
668
+ super().__init__(config)
669
+ self.config = config
670
+
671
+ self.bacformer = BacformerModel(config, add_pooling_layer=False)
672
+ self.gm_head = BacformerGMHead(config)
673
+
674
+ # Initialize weights
675
+ self.init_weights()
676
+
677
+ def forward(
678
+ self,
679
+ protein_embeddings: torch.Tensor,
680
+ special_tokens_mask: torch.Tensor,
681
+ labels: torch.Tensor = None,
682
+ token_type_ids: torch.Tensor = None,
683
+ attention_mask: torch.Tensor = None,
684
+ return_attn_weights: bool = None,
685
+ return_dict: Union[bool, None] = None,
686
+ ) -> Union[BacformerModelOutput, None]:
687
+ """Forward method for the model."""
688
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
689
+ return_attn_weights = (
690
+ return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
691
+ )
692
+
693
+ outputs = self.bacformer(
694
+ protein_embeddings=protein_embeddings,
695
+ special_tokens_mask=special_tokens_mask,
696
+ token_type_ids=token_type_ids,
697
+ attention_mask=attention_mask,
698
+ return_attn_weights=return_attn_weights,
699
+ return_dict=return_dict,
700
+ )
701
+ last_hidden_state = outputs[0]
702
+
703
+ # to speed up the forward pass, let's only consider the masked tokens
704
+
705
+ loss = None
706
+ if labels is not None:
707
+ # to speed up the forward pass, let's only consider the masked tokens
708
+ last_hidden_state = last_hidden_state[labels != -100]
709
+ prediction_scores = self.gm_head(last_hidden_state)
710
+ labels = labels.to(prediction_scores.device)
711
+
712
+ ### notes
713
+ # use the labels to get -100 for non-masked tokens
714
+ # do not use special_tokens_mask
715
+ # check how the labels are constructed
716
+
717
+ # only considering the masked tokens
718
+ labels = labels[labels != -100]
719
+ loss = cross_entropy(prediction_scores, labels)
720
+ else:
721
+ prediction_scores = self.gm_head(last_hidden_state)
722
+
723
+ if not return_dict:
724
+ return (
725
+ loss,
726
+ prediction_scores,
727
+ ) + outputs
728
+
729
+ return BacformerModelOutput(
730
+ loss=loss,
731
+ logits=prediction_scores,
732
+ last_hidden_state=outputs.last_hidden_state,
733
+ attentions=outputs.attentions,
734
+ )
735
+
736
+
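# Illustration (not part of the upload): BacformerForMaskedGM runs the LM head only on masked
# positions. Boolean indexing with labels != -100 flattens the batch and sequence dimensions, so
# the head sees an (n_masked, hidden_size) matrix.
import torch

hidden = torch.randn(2, 4, 8)                                        # (batch, seq, hidden)
labels = torch.tensor([[-100, 3, -100, 7], [5, -100, -100, -100]])
masked_hidden = hidden[labels != -100]                               # -> shape (3, 8)
masked_labels = labels[labels != -100]                               # -> tensor([3, 7, 5])
print(masked_hidden.shape, masked_labels.tolist())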
737
+ class BacformerForCausalProteinFamilyModeling(BacformerPreTrainedModel):
738
+ """Bacformer model for causal modeling of protein families. Using protein family as tokens rather than protein embeddings"""
739
+
740
+ _tied_weights_keys = ["gm_head.decoder.weight"]
741
+
742
+ def __init__(
743
+ self,
744
+ config: BacformerConfig,
745
+ n_conditional_properties: int = None,
746
+ initialise_from_non_pfm_model: bool = False,
747
+ ):
748
+ super().__init__(config)
749
+ self.config = config
750
+ self.cls_token_id = SPECIAL_TOKENS_DICT["CLS"]
751
+
752
+ self.bacformer = BacformerModel(config, add_pooling_layer=False)
753
+ self.gm_head = BacformerGMHead(config)
754
+
755
+ if initialise_from_non_pfm_model:
756
+ # Initialize weights
757
+ self.init_weights()
758
+ # overwrite the embeddings with the pretrained
759
+ # protein family embeddings from the decoder of the GM Head
760
+ self.bacformer.embeddings = BacformerProteinFamilyEmbeddings(
761
+ config,
762
+ protein_family_embeddings=self.gm_head.decoder.weight,
763
+ token_type_embeddings=self.bacformer.embeddings.token_type_embeddings.weight,
764
+ special_tokens_embeddings=self.bacformer.embeddings.special_tokens_embeddings.weight,
765
+ n_conditional_properties=n_conditional_properties,
766
+ )
767
+ else:
768
+ self.bacformer.embeddings = BacformerProteinFamilyEmbeddings(
769
+ config,
770
+ n_conditional_properties=n_conditional_properties,
771
+ )
772
+ self.init_weights()
773
+
774
+ def forward(
775
+ self,
776
+ labels: torch.Tensor = None,
777
+ special_tokens_mask: torch.Tensor = None,
778
+ token_type_ids: torch.Tensor = None,
779
+ property_ids: torch.Tensor = None,
780
+ return_attn_weights: bool = None,
781
+ return_dict: Union[bool, None] = None,
782
+ ) -> Optional[BacformerModelOutput]:
783
+ """Forward method for the model."""
784
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
785
+ return_attn_weights = (
786
+ return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
787
+ )
788
+
789
+ outputs = self.bacformer(
790
+ protein_embeddings=None,
791
+ labels=labels,
792
+ special_tokens_mask=special_tokens_mask,
793
+ token_type_ids=token_type_ids,
794
+ property_ids=property_ids,
795
+ return_attn_weights=return_attn_weights,
796
+ return_dict=return_dict,
797
+ is_causal=True,
798
+ )
799
+ last_hidden_state = outputs[0]
800
+ prediction_scores = self.gm_head(last_hidden_state)
801
+
802
+ loss = None
803
+ if labels is not None:
804
+ if property_ids is not None:
805
+ labels = torch.cat(
806
+ [
807
+ torch.tensor([-100], dtype=torch.long)
808
+ .unsqueeze(0)
809
+ .to(labels.device), # account for the property token
810
+ labels,
811
+ ],
812
+ dim=1,
813
+ ) # ignore index
814
+ labels = labels.to(prediction_scores.device)
815
+
816
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous().view(-1, prediction_scores.shape[-1])
817
+ labels = labels[:, 1:].contiguous().view(-1)
818
+ loss = cross_entropy(shifted_prediction_scores, labels)
819
+
820
+ if not return_dict:
821
+ return (
822
+ loss,
823
+ prediction_scores,
824
+ ) + outputs
825
+
826
+ return BacformerModelOutput(
827
+ loss=loss,
828
+ logits=prediction_scores,
829
+ last_hidden_state=outputs.last_hidden_state,
830
+ attentions=outputs.attentions,
831
+ )
832
+
833
+ def generate(
834
+ self,
835
+ protein_family_ids: torch.LongTensor,
836
+ special_tokens_mask: torch.LongTensor = None,
837
+ token_type_ids: torch.LongTensor = None,
838
+ max_length: int = 6000,
839
+ end_token_id: int = 50000,
840
+ do_sample: bool = False,
841
+ top_k: int = 50,
842
+ top_p: float = 1.0,
843
+ temperature: float = 1.0,
844
+ property_ids: torch.LongTensor = None,
845
+ return_last_hidden_states: bool = False,
846
+ ):
847
+ """
848
+ Generate a sequence of tokens autoregressively from a given prompt.
849
+
850
+ Args:
851
+ protein_family_ids (torch.LongTensor): Tensor of shape (batch, seq_len) with token indices.
852
+ max_length (int): Maximum length of the generated sequence (prompt + newly generated).
853
+ end_token_id (int, optional): Token ID signifying end-of-sequence (END).
854
+ If encountered, generation stops.
855
+ do_sample (bool): Whether to sample from the probability distribution (True)
856
+ or use greedy decoding (False).
857
+ top_k (int): If >0, use top-k filtering in sampling mode.
858
+ top_p (float): If <1.0, use nucleus (top-p) filtering in sampling mode.
859
+ temperature (float): Softmax temperature for scaling logits.
860
+ Higher => more random, lower => more deterministic.
861
+ return_last_hidden_states (bool): If True, return final hidden states as well.
862
+
863
+ Returns
864
+ -------
865
+ torch.LongTensor: The generated token sequence of shape (batch, final_seq_len).
866
+ (Optional) torch.FloatTensor: Final hidden states of shape (batch, final_seq_len, hidden_dim)
867
+ if `return_last_hidden_states=True`.
868
+ """
869
+ # Default END token
870
+ if end_token_id is None:
871
+ end_token_id = getattr(self, "end_token_id", None)
872
+
873
+ # Switch to eval mode and move input to correct device
874
+ self.eval()
875
+ device = next(self.parameters()).device
876
+ protein_family_ids = protein_family_ids.to(device)
877
+
878
+ # create a special tokens mask if not provided
879
+ if special_tokens_mask is None:
880
+ # add a cls token at the beginning
881
+ protein_family_ids = torch.cat(
882
+ [torch.tensor([[-100]]).to(device), protein_family_ids],
883
+ dim=1,
884
+ )
885
+ special_tokens_mask = [self.cls_token_id] + [self.config.prot_emb_token_id] * (
886
+ protein_family_ids.shape[1] - 1
887
+ )
888
+ special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.long).to(device)
889
+
890
+ # create a token type mask if not provided
891
+ if token_type_ids is None:
892
+ token_type_ids = torch.zeros_like(protein_family_ids)
893
+
894
+ # Prepare the initial sequence and define max new tokens
895
+ generated = protein_family_ids.clone()
896
+ batch_size, prompt_length = generated.shape
897
+ max_new_tokens = max_length - prompt_length
898
+ if max_new_tokens <= 0:
899
+ max_new_tokens = 0
900
+
901
+ # Disable gradient calculations for generation
902
+ with torch.no_grad():
903
+ for _step in range(max_new_tokens):
904
+ # Forward pass
905
+ logits = self.forward(
906
+ labels=generated,
907
+ special_tokens_mask=special_tokens_mask,
908
+ # assume it's all on one chromosome
909
+ token_type_ids=token_type_ids,
910
+ property_ids=property_ids,
911
+ return_dict=True,
912
+ ).logits
913
+ # Focus on the last token's logits
914
+ next_token_logits = logits[:, -1, :] # (batch_size, vocab_size)
915
+
916
+ # Apply temperature
917
+ if temperature != 1.0:
918
+ next_token_logits = next_token_logits / temperature
919
+
920
+ # Sampling or greedy?
921
+ if do_sample:
922
+ # Top-k filter
923
+ next_token_logits = top_k_filtering(next_token_logits, top_k=top_k)
924
+ # Top-p filter
925
+ next_token_logits = top_p_filtering(next_token_logits, top_p=top_p)
926
+
927
+ probs = softmax(next_token_logits, dim=-1)
928
+ next_token_id = torch.multinomial(probs, num_samples=1)
929
+ else:
930
+ # Greedy decoding
931
+ next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
932
+
933
+ # Append predicted token
934
+ generated = torch.cat([generated, next_token_id], dim=1)
935
+ special_tokens_mask = torch.cat(
936
+ [special_tokens_mask, torch.tensor([[self.config.prot_emb_token_id]]).to(generated.device)], dim=1
937
+ )
938
+ last_token_type_id = token_type_ids[:, -1].unsqueeze(1)
939
+ token_type_ids = torch.cat([token_type_ids, last_token_type_id], dim=1)
940
+
941
+ # Check for END in all sequences
942
+ if end_token_id is not None:
943
+ if (next_token_id.squeeze(1) == end_token_id).all():
944
+ # If every sequence ended, break early
945
+ break
946
+
947
+ if not return_last_hidden_states:
948
+ return generated
949
+
950
+ # Optionally compute final hidden states
951
+ if return_last_hidden_states:
952
+ last_hidden_state = self.forward(
953
+ labels=generated,
954
+ special_tokens_mask=special_tokens_mask,
955
+ token_type_ids=token_type_ids,
956
+ return_dict=True,
957
+ ).last_hidden_state
958
+
959
+ return generated, last_hidden_state
960
+
961
+
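# Hypothetical usage (not part of the upload) of the generate() method above; assumes `model` is an
# already-loaded BacformerForCausalProteinFamilyModeling checkpoint and that the prompt contains
# protein family ids below protein_clusters_vocab_size. The ids below are made up.
import torch

prompt = torch.tensor([[11, 254, 9081]])
generated = model.generate(
    protein_family_ids=prompt,
    max_length=32,        # prompt (plus the prepended CLS slot) + newly sampled families
    do_sample=True,
    top_k=50,
    top_p=0.95,
    temperature=1.0,
)
print(generated.shape)    # (1, <=32)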
962
+ class BacformerForMaskedGMWithContrastiveLoss(BacformerPreTrainedModel):
963
+ """Bacformer model with genomic modeling head on top"""
964
+
965
+ _tied_weights_keys = ["gm_head.decoder.weight"]
966
+
967
+ def __init__(self, config: BacformerConfig):
968
+ super().__init__(config)
969
+ self.config = config
970
+
971
+ self.bacformer = BacformerModel(config, add_pooling_layer=False)
972
+ self.gm_head = BacformerGMHead(config)
973
+
974
+ # Initialize weights
975
+ self.init_weights()
976
+
977
+ def forward(
978
+ self,
979
+ protein_embeddings: torch.Tensor,
980
+ special_tokens_mask: torch.Tensor,
981
+ labels: torch.Tensor = None,
982
+ token_type_ids: torch.Tensor = None,
983
+ attention_mask: torch.Tensor = None,
984
+ return_attn_weights: bool = None,
985
+ return_dict: Union[bool, None] = None,
986
+ ) -> Union[BacformerModelOutput, None]:
987
+ """Forward method for the model."""
988
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
989
+ return_attn_weights = (
990
+ return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
991
+ )
992
+
993
+ outputs = self.bacformer(
994
+ protein_embeddings=protein_embeddings,
995
+ special_tokens_mask=special_tokens_mask,
996
+ token_type_ids=token_type_ids,
997
+ attention_mask=attention_mask,
998
+ return_attn_weights=return_attn_weights,
999
+ return_dict=return_dict,
1000
+ )
1001
+ last_hidden_state = outputs[0]
1002
+
1003
+ # to speed up the forward pass, let's only consider the masked tokens
1004
+
1005
+ loss = None
1006
+ if labels is not None:
1007
+ # contrastive loss
1008
+ contrastive_loss = compute_contrastive_loss(protein_embeddings, last_hidden_state, special_tokens_mask)
1009
+ # to speed up the forward pass, let's only consider the masked tokens
1010
+ last_hidden_state = last_hidden_state[labels != -100]
1011
+ prediction_scores = self.gm_head(last_hidden_state)
1012
+ labels = labels.to(prediction_scores.device)
1013
+
1014
+ # only considering the masked tokens
1015
+ labels = labels[labels != -100]
1016
+ masked_loss = cross_entropy(prediction_scores, labels)
1017
+ loss = masked_loss + self.config.alpha_contrastive_loss * contrastive_loss
1018
+ else:
1019
+ prediction_scores = self.gm_head(last_hidden_state)
1020
+
1021
+ if not return_dict:
1022
+ return (
1023
+ loss,
1024
+ prediction_scores,
1025
+ ) + outputs
1026
+
1027
+ return BacformerModelOutput(
1028
+ loss=loss,
1029
+ logits=prediction_scores,
1030
+ last_hidden_state=outputs.last_hidden_state,
1031
+ attentions=outputs.attentions,
1032
+ )
1033
+
1034
+
1035
+ class BacformerForProteinClassification(BacformerPreTrainedModel):
1036
+ """Bacformer model with a classification head on top for protein classification tasks."""
1037
+
1038
+ def __init__(self, config: BacformerConfig, benchmark_esm: bool = False):
1039
+ super().__init__(config)
1040
+ self.config = config
1041
+ self.benchmark_esm = benchmark_esm
1042
+
1043
+ self.bacformer = BacformerModel(config, add_pooling_layer=False)
1044
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1045
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1046
+
1047
+ # Initialize weights and apply final processing
1048
+ self.post_init()
1049
+
1050
+ def forward(
1051
+ self,
1052
+ protein_embeddings: torch.Tensor,
1053
+ special_tokens_mask: torch.Tensor,
1054
+ labels: torch.Tensor = None,
1055
+ token_type_ids: torch.Tensor = None,
1056
+ attention_mask: torch.Tensor = None,
1057
+ return_attn_weights: bool = None,
1058
+ return_dict: Union[bool, None] = None,
1059
+ ) -> Optional[BacformerModelOutput]:
1060
+ """Forward method for the model."""
1061
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1062
+ return_attn_weights = (
1063
+ return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
1064
+ )
1065
+
1066
+ if self.benchmark_esm:
1067
+ outputs = [protein_embeddings]
1068
+ else:
1069
+ outputs = self.bacformer(
1070
+ protein_embeddings=protein_embeddings,
1071
+ special_tokens_mask=special_tokens_mask,
1072
+ token_type_ids=token_type_ids,
1073
+ attention_mask=attention_mask,
1074
+ return_attn_weights=return_attn_weights,
1075
+ return_dict=return_dict,
1076
+ )
1077
+
1078
+ last_hidden_state = outputs[0]
1079
+
1080
+ last_hidden_state = self.dropout(last_hidden_state)
1081
+ logits = self.classifier(last_hidden_state)
1082
+
1083
+ loss = None
1084
+ if labels is not None:
1085
+ labels = labels.to(logits.device)
1086
+
1087
+ if self.config.problem_type == "regression":
1088
+ loss = mse_loss(logits, labels)
1089
+ elif self.config.problem_type == "single_label_classification":
1090
+ loss = cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
1091
+ elif (
1092
+ self.config.problem_type == "multi_label_classification"
1093
+ or self.config.problem_type == "binary_classification"
1094
+ ):
1095
+ # remove the -100 labels from loss computation
1096
+ mask = torch.ones_like(labels.view(-1)) - (labels.view(-1) == -100.0).float()
1097
+ loss = binary_cross_entropy_with_logits(
1098
+ logits.view(-1), labels.view(-1).type_as(logits), reduction="none"
1099
+ )
1100
+ loss = (loss * mask).sum() / mask.sum()
1101
+
1102
+ if not return_dict:
1103
+ return (
1104
+ loss,
1105
+ None,
1106
+ logits,
1107
+ ) # + outputs
1108
+
1109
+ return BacformerModelOutput(
1110
+ loss=loss,
1111
+ logits=logits,
1112
+ last_hidden_state=last_hidden_state,
1113
+ attentions=outputs.attentions,
1114
+ )
1115
+
1116
+
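# Illustration (not part of the upload): the multi-label / binary branch above masks out positions
# labelled -100 before averaging the element-wise BCE, so ignored targets contribute nothing.
import torch
from torch.nn.functional import binary_cross_entropy_with_logits

logits = torch.randn(5)
labels = torch.tensor([1.0, 0.0, -100.0, 1.0, -100.0])
mask = torch.ones_like(labels) - (labels == -100.0).float()
loss = binary_cross_entropy_with_logits(logits, labels, reduction="none")
print(((loss * mask).sum() / mask.sum()).item())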
1117
+ class BacformerForGenomeClassification(BacformerPreTrainedModel):
1118
+ """Bacformer model with a classification head on top for genome classification tasks."""
1119
+
1120
+ def __init__(self, config: BacformerConfig):
1121
+ super().__init__(config)
1122
+ self.config = config
1123
+
1124
+ self.bacformer = BacformerModel(config, add_pooling_layer=False)
1125
+ self.classifier = BacformerGenomeClassificationHead(config)
1126
+
1127
+ # Initialize weights and apply final processing
1128
+ self.post_init()
1129
+
1130
+ def forward(
1131
+ self,
1132
+ protein_embeddings: torch.Tensor,
1133
+ special_tokens_mask: torch.Tensor,
1134
+ labels: torch.Tensor = None,
1135
+ token_type_ids: torch.Tensor = None,
1136
+ attention_mask: torch.Tensor = None,
1137
+ return_attn_weights: bool = None,
1138
+ return_dict: Union[bool, None] = None,
1139
+ ) -> Optional[BacformerModelOutput]:
1140
+ """Forward method for the model."""
1141
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1142
+ return_attn_weights = (
1143
+ return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
1144
+ )
1145
+
1146
+ outputs = self.bacformer(
1147
+ protein_embeddings=protein_embeddings,
1148
+ special_tokens_mask=special_tokens_mask,
1149
+ token_type_ids=token_type_ids,
1150
+ attention_mask=attention_mask,
1151
+ return_attn_weights=return_attn_weights,
1152
+ return_dict=return_dict,
1153
+ )
1154
+ last_hidden_state = outputs[0]
1155
+ logits = self.classifier(last_hidden_state, attention_mask)
1156
+
1157
+ loss = None
1158
+ if labels is not None:
1159
+ labels = labels.to(logits.device)
1160
+
1161
+ if self.config.problem_type == "regression":
1162
+ loss = mse_loss(logits.view(-1), labels.view(-1))
1163
+ elif self.config.problem_type == "binary_classification":
1164
+ loss = binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
1165
+ elif self.config.problem_type == "single_label_classification":
1166
+ loss = cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
1167
+ elif self.config.problem_type == "multi_label_classification":
1168
+ loss = binary_cross_entropy_with_logits(logits, labels)
1169
+
1170
+ if not return_dict:
1171
+ return (
1172
+ loss,
1173
+ None,
1174
+ logits,
1175
+ )
1176
+
1177
+ return BacformerModelOutput(
1178
+ loss=loss,
1179
+ logits=logits,
1180
+ last_hidden_state=outputs.last_hidden_state,
1181
+ attentions=outputs.attentions,
1182
+ )
1183
+
1184
+
1185
+ class BacformerForProteinProteinInteraction(BacformerPreTrainedModel):
+     """Bacformer model with a protein-protein interaction head on top."""
+
+     def __init__(self, config: BacformerConfig, benchmark_esm: bool = False):
+         super().__init__(config)
+         self.config = config
+         self.benchmark_esm = benchmark_esm
+         print("Benchmark ESM:", self.benchmark_esm)
+         self.return_attn_weights = config.return_attn_weights
+
+         self.bacformer = BacformerModel(config, add_pooling_layer=False)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.dense = nn.Sequential(
+             nn.Linear(config.hidden_size, config.hidden_size),
+             nn.GELU(),
+             nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
+             nn.Dropout(0.2),
+         )
+         self.ppi_head = BacformerProteinProteinInteractionHead(
+             in_features=config.hidden_size, prot_emb_idx=config.prot_emb_token_id
+         )
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         protein_embeddings: torch.Tensor,
+         special_tokens_mask: torch.Tensor,
+         labels: torch.Tensor = None,
+         token_type_ids: torch.Tensor = None,
+         attention_mask: torch.Tensor = None,
+         return_attn_weights: bool = None,
+         return_dict: Union[bool, None] = None,
+     ) -> Union[OrderedDict, None]:  # TODO: change it from token classifier output
+         """Forward method for the model."""
+         return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+         if self.benchmark_esm:
+             last_hidden_state = protein_embeddings.squeeze(0)[1:-2, :]
+         else:
+             outputs = self.bacformer(
+                 protein_embeddings=protein_embeddings,
+                 special_tokens_mask=special_tokens_mask,
+                 token_type_ids=token_type_ids,
+                 attention_mask=attention_mask,
+                 return_attn_weights=False,
+                 return_dict=True,
+             )
+             last_hidden_state = outputs.last_hidden_state.squeeze(0)[1:-2, :]
+
+         assert labels.shape[0] == 1, "Batch size should be 1 for protein-protein interaction task"
+
+         last_hidden_state = self.dense(self.dropout(last_hidden_state))
+         last_hidden_state = torch.cat([last_hidden_state[labels[:, 0]], last_hidden_state[labels[:, 1]]], dim=0).mean(
+             dim=0
+         )
+         logits = self.ppi_head(last_hidden_state)
+
+         loss = binary_cross_entropy_with_logits(logits, labels[:, 2].type_as(logits).squeeze(0))
+
+         if not return_dict:
+             return (
+                 loss,
+                 logits,
+             )
+
+         return BacformerModelOutput(
+             loss=loss,
+             logits=logits,
+             last_hidden_state=outputs.last_hidden_state,
+             attentions=outputs.attentions,
+         )
+
+
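The PPI model scores one candidate protein pair per call: `labels` holds the two protein indices (positions after the leading and trailing special tokens are stripped by the `[1:-2]` slice) and the binary interaction label. A hedged sketch, with hypothetical import paths and placeholder inputs:

import torch
# hypothetical import paths; adjust to the installed module layout
from configuration_bacformer import BacformerConfig
from modeling_bacformer import BacformerForProteinProteinInteraction

config = BacformerConfig()
model = BacformerForProteinProteinInteraction(config).eval()

n_proteins = 32
protein_embeddings = torch.randn(1, n_proteins, config.hidden_size)
special_tokens_mask = torch.zeros(1, n_proteins, dtype=torch.long)  # placeholder values

# one candidate pair per call: (index of protein A, index of protein B, interaction label)
labels = torch.tensor([[3, 17, 1]])

out = model(
    protein_embeddings=protein_embeddings,
    special_tokens_mask=special_tokens_mask,
    labels=labels,
)
print(out.loss, out.logits)  # a single interaction logit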
1260
+ # Copied from transformers.models.bert.modeling_bert.BertPooler
+ class BacformerPooler(nn.Module):
+     """Pooler for Bacformer model."""
+
+     def __init__(self, config: BacformerConfig):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.activation = nn.Tanh()
+
+     def forward(self, hidden_states: torch.Tensor, padding_mask: torch.Tensor = None) -> torch.Tensor:
+         """Forward method for the pooler."""
+         # We "pool" the model by taking the mean of non-padding tokens
+         padding_mask = padding_mask.to(hidden_states.device) if padding_mask is not None else None
+         if padding_mask is not None:
+             mean_hidden_states = torch.einsum("ijk,ij->ik", hidden_states, padding_mask) / padding_mask.sum(
+                 1
+             ).unsqueeze(1)
+         else:
+             mean_hidden_states = hidden_states.mean(dim=1)
+         pooled_output = self.dense(mean_hidden_states)
+         pooled_output = self.activation(pooled_output)
+         return pooled_output
+
+
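The einsum in the pooler (and in the classification head further below) is simply a masked mean over non-padding positions. A small self-contained check, independent of the model classes:

import torch

hidden_states = torch.randn(2, 5, 8)                    # (batch, seq_len, hidden)
padding_mask = torch.tensor([[1, 1, 1, 0, 0],
                             [1, 1, 1, 1, 1]]).float()  # 1 = real token, 0 = padding

pooled = torch.einsum("ijk,ij->ik", hidden_states, padding_mask) / padding_mask.sum(1).unsqueeze(1)

# equivalent broadcasting formulation
expected = (hidden_states * padding_mask.unsqueeze(-1)).sum(1) / padding_mask.sum(1, keepdim=True)
assert torch.allclose(pooled, expected)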
1284
+ class BacformerGMHead(nn.Module):
+     """Bacformer Head for genomic modeling."""
+
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+         # add 1 to config.protein_clusters_vocab_size to account for the end token
+         self.decoder = nn.Linear(config.hidden_size, config.protein_clusters_vocab_size + 1, bias=False)
+         self.bias = nn.Parameter(torch.zeros(config.protein_clusters_vocab_size + 1))
+
+     def forward(self, features, **kwargs):
+         """Forward method for the head."""
+         x = self.dense(features)
+         x = gelu(x)
+         x = self.layer_norm(x)
+
+         # project back to nr of labels with bias
+         x = self.decoder(x) + self.bias
+         return x
+
+
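The head maps each position's hidden state to logits over the protein-cluster vocabulary plus one extra slot for the end token. A standalone sketch of the same computation, using hypothetical sizes (the real values come from `BacformerConfig`):

import torch
from torch import nn
from torch.nn.functional import gelu

hidden_size, vocab_size = 480, 50_000           # hypothetical vocabulary size
dense = nn.Linear(hidden_size, hidden_size)
layer_norm = nn.LayerNorm(hidden_size)
decoder = nn.Linear(hidden_size, vocab_size + 1, bias=False)
bias = nn.Parameter(torch.zeros(vocab_size + 1))

features = torch.randn(1, 10, hidden_size)      # (batch, n_proteins, hidden)
logits = decoder(layer_norm(gelu(dense(features)))) + bias
print(logits.shape)                             # torch.Size([1, 10, 50001])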
1307
+ class BacformerGenomeClassificationHead(nn.Module):
+     """Head for genome-level classification tasks."""
+
+     def __init__(self, config: BacformerConfig):
+         super().__init__()
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+     def forward(self, features: torch.Tensor, padding_mask: torch.Tensor, **kwargs):
+         """Forward method for the head."""
+         if padding_mask is not None:
+             x = torch.einsum("ijk,ij->ik", features, padding_mask) / padding_mask.sum(1).unsqueeze(1)
+         else:
+             x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+         x = self.dropout(x)
+         x = self.out_proj(x)
+         return x
+
+
1326
+ class BacformerProteinProteinInteractionHead(nn.Module):
+     """Head for protein-protein interaction task at a genome level."""
+
+     def __init__(self, in_features: int, prot_emb_idx: int = 4, bias: bool = True):
+         super().__init__()
+         self.in_features = in_features
+         self.prot_emb_idx = prot_emb_idx
+         self.dropout = nn.Dropout(0.2)
+         self.linear = nn.Linear(in_features, 1, bias=bias)
+
+     def forward(
+         self, hidden_states: torch.Tensor
+     ) -> torch.Tensor:  # special_tokens_mask: torch.Tensor, attentions: torch.Tensor):
+         """Forward method for the head."""
+         return self.linear(self.dropout(hidden_states)).squeeze(-1)
utils_bacformer.py ADDED
@@ -0,0 +1,109 @@
+ import torch
+ from torch.nn.functional import cross_entropy, softmax
+
+ from .configuration_bacformer import SPECIAL_TOKENS_DICT
+
+
+ def compute_contrastive_loss(
+     protein_embeddings: torch.Tensor,
+     last_hidden_state: torch.Tensor,
+     special_tokens_mask: torch.Tensor,
+ ) -> torch.Tensor:
+     """Compute contrastive loss between protein embeddings and masked items."""
+     # keep protein embeddings and masked items
+     # ensure the batch size is 1, the model currently does not work with batch size > 1
+     assert protein_embeddings.shape[0] == last_hidden_state.shape[0] == 1
+
+     # subset to mask and protein embedding tokens
+     special_tokens_mask = special_tokens_mask.squeeze(0)
+     mask = (special_tokens_mask == SPECIAL_TOKENS_DICT["PROT_EMB"]) | (
+         special_tokens_mask == SPECIAL_TOKENS_DICT["MASK"]
+     )
+     protein_embeddings = protein_embeddings.squeeze(0)[mask]
+     last_hidden_state = last_hidden_state.squeeze(0)[mask]
+
+     # Normalize embeddings
+     last_hidden_state = last_hidden_state / last_hidden_state.norm(dim=1, keepdim=True)
+     protein_embeddings = protein_embeddings / protein_embeddings.norm(dim=1, keepdim=True)
+
+     # Compute similarity matrix and loss as before
+     similarity_matrix = torch.matmul(last_hidden_state, protein_embeddings.T)
+
+     n_prots = protein_embeddings.shape[0]
+     labels = torch.arange(n_prots).to(protein_embeddings.device)
+
+     # Compute the loss
+     loss = cross_entropy(similarity_matrix, labels)
+     return loss
+
+
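This is an InfoNCE-style objective: each contextual hidden state should be most similar to its own input protein embedding among all proteins in the genome. A toy illustration of the core computation, bypassing the `SPECIAL_TOKENS_DICT` masking (whose values live in `configuration_bacformer.py`):

import torch
from torch.nn.functional import cross_entropy

n_prots, dim = 8, 16
protein_embeddings = torch.randn(n_prots, dim)
last_hidden_state = torch.randn(n_prots, dim)

protein_embeddings = protein_embeddings / protein_embeddings.norm(dim=1, keepdim=True)
last_hidden_state = last_hidden_state / last_hidden_state.norm(dim=1, keepdim=True)

similarity = last_hidden_state @ protein_embeddings.T  # (n_prots, n_prots)
loss = cross_entropy(similarity, torch.arange(n_prots))
print(loss)  # ~log(8) ≈ 2.08 for unrelated random vectors; lower once each row aligns with its own protein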
40
+ def top_k_filtering(logits: torch.Tensor, top_k: int = 50):
+     """
+     Keep only top_k logits and set the rest to -inf.
+
+     Args:
+         logits (torch.Tensor): Logits of shape (batch_size, vocab_size).
+         top_k (int): The number of highest probability logits to keep.
+
+     Returns
+     -------
+     torch.Tensor: Filtered logits where only the top k values remain, and all others are -inf.
+     """
+     if top_k <= 0:
+         return logits
+
+     # Find top_k values
+     top_k = min(top_k, logits.size(-1))
+     vals, idx = torch.topk(logits, top_k, dim=-1)
+     # Get the smallest logit in the top_k
+     min_vals = vals[:, -1].unsqueeze(-1)
+     # Mask all logits that are < this min value
+     mask = logits < min_vals
+     logits[mask] = float("-inf")
+     return logits
+
+
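A quick check of the filter on a tiny vocabulary; note that it modifies `logits` in place, so clone first if the original tensor is still needed (the import path is hypothetical):

import torch
from utils_bacformer import top_k_filtering  # hypothetical module path for the helper above

logits = torch.tensor([[0.1, 2.0, -1.0, 3.0, 0.5]])
filtered = top_k_filtering(logits.clone(), top_k=2)
print(filtered)  # tensor([[-inf, 2.0, -inf, 3.0, -inf]])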
66
+ def top_p_filtering(logits: torch.Tensor, top_p: float = 0.9):
+     """
+     Keep the smallest set of logits whose cumulative probability >= top_p.
+
+     Args:
+         logits (torch.Tensor): Logits of shape (batch_size, vocab_size).
+         top_p (float): Cumulative probability threshold.
+
+     Returns
+     -------
+     torch.Tensor: Filtered logits where only tokens within the top_p cumulative
+         probability mass are kept; the rest are set to -inf.
+     """
+     if top_p >= 1.0:
+         return logits
+
+     sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+     cumulative_probs = torch.cumsum(softmax(sorted_logits, dim=-1), dim=-1)
+
+     # Identify where cumulative probability exceeds top_p
+     sorted_indices_to_remove = cumulative_probs > top_p
+     # Shift the mask to ensure we always keep at least one token
+     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+     sorted_indices_to_remove[..., 0] = False
+
+     # Scatter to replicate the mask in the original ordering
+     for i in range(logits.size(0)):
+         remove_indices = sorted_indices[i, sorted_indices_to_remove[i]]
+         logits[i, remove_indices] = float("-inf")
+
+     return logits
+
+
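Together the two filters implement the usual top-k / nucleus sampling recipe for autoregressive generation. A minimal sampling sketch under the assumption that the helpers are importable from this module:

import torch
from torch.nn.functional import softmax
from utils_bacformer import top_k_filtering, top_p_filtering  # hypothetical module path

logits = torch.randn(1, 1000)             # e.g. next-protein-cluster logits from BacformerGMHead
logits = top_k_filtering(logits, top_k=50)
logits = top_p_filtering(logits, top_p=0.9)

probs = softmax(logits, dim=-1)           # -inf logits become probability 0
next_token = torch.multinomial(probs, num_samples=1)
print(next_token.shape)                   # torch.Size([1, 1])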
99
+ def create_4d_from_2d_attn_mask(attn_mask: torch.Tensor, num_attn_heads: int):
+     """Helper function to reshape attn_mask to 4D from 2D"""
+     assert (
+         len(attn_mask.shape) == 2
+     ), f"Please provide attn_mask of shape (batch_size, seq_len), current shape {attn_mask.shape}"
+
+     bs, seq_len = attn_mask.shape
+     attn_mask = attn_mask.view(bs, 1, 1, seq_len)
+     attn_mask = attn_mask.expand(-1, num_attn_heads, -1, -1)
+     attn_mask = attn_mask.view(bs, num_attn_heads, -1, seq_len)
+     return attn_mask
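The helper broadcasts a per-token padding mask to one mask per attention head. A quick shape check (the import path is hypothetical):

import torch
from utils_bacformer import create_4d_from_2d_attn_mask  # hypothetical module path

attn_mask = torch.tensor([[1, 1, 1, 0],
                          [1, 1, 0, 0]])  # (batch_size=2, seq_len=4)
expanded = create_4d_from_2d_attn_mask(attn_mask, num_attn_heads=8)
print(expanded.shape)                     # torch.Size([2, 8, 1, 4])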