Commit
·
4ae63c9
1
Parent(s):
a9ee239
update readme
Browse files- README.md +23 -12
- examples.pt +3 -0
README.md
CHANGED
@@ -31,7 +31,7 @@ Here we briefly introduce the details of pre-training of AIDO.RAGProtein-16B. Ma
|
|
31 |
|
32 |
### Data
|
33 |
|
34 |
-
**UniRef50/Uniclust30 MSA dataset**: We utilized sequences from UniRef50 as queries to search for homologous sequences in UniClust30, subsequently constructing multiple sequence alignments (MSAs). UniRef50 comprises a total of 53.6 million sequences. Using HHblits, we searched all sequences, identifying over 25 homologous sequences for 23.7 million of them. This dataset was directly used as the training set, referred to as `HHblits_MSA`. The remaining 29.9 million sequences were input into MSA Retriever, resulting in 7.7 million sequences with more than 25 homologous sequences. This dataset was designated as `Retriever_MSA`. During training, RAGPLM randomly sampled from the two datasets with probabilities of 0.75 and 0.25
|
35 |
|
36 |
**AlphaFold Database MSA & Structure dataset**: We downloaded all the structural data from the AlphaFold Database and only kept the structures where the amino acid ratio of pLDDT>70 was greater than 40%. Then we used `mmseqs` to cluster the remaining sequences with `seq id=0.5`, and retained a representative sequence for each class. Final we get 46.9 million sequence/structure pairs. For each structure, we used [genbio-ai/AIDO.StructureTokenizer](https://huggingface.co/genbio-ai/AIDO.StructureTokenizer) to obtain the corresponding structure tokens and structure embedding. And used [MSA Retriever](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1) to obtain the MSA corresponding to the sequence.
|
37 |
|
@@ -45,34 +45,45 @@ Same training data with [AIDO.Protein-16B](https://huggingface.co/genbio-ai/AIDO
|
|
45 |
|
46 |
#### (2) UniRef50/Uniclust30 MSA finetuning
|
47 |
|
48 |
-
We used UniRef50/Uniclust30 MSA dataset to finetune the model from stage (1). Refer
|
49 |
|
50 |
#### (3) AFDB MSA & Structure tokens finetuning:
|
51 |
|
52 |
We fine-tuned a pretrained masked language model using MSA data by concatenating the query sequence with homologous sequences. The input structure embedding (hidden dimension 384) is linearly mapped to 2304 and then added to the corresponding embedding of the query sequence tokens.
|
53 |
|
54 |
-
**
|
55 |
|
56 |
-
**
|
57 |
|
58 |
-
**Positional embedding**: To help the model distinguish which tokens are from the same chain and which tokens have the same residue index, we use [2D rotary position embedding](https://arxiv.org/abs/2406.05347) to encode the tokens.
|
59 |
|
60 |
-
**Loss**: The loss function consists of a sequence loss function and a structure loss function (weights are 1.0 and 0.01 respectively). The sequence loss function is the CrossEntropy
|
61 |
|
62 |
| Hyper-params | (1) 1D -> 2D finetuning | (2) UniRef50/Uniclust30 MSA finetuning | (3) AFDB MSA & Structure tokens finetuning |
|
63 |
| --------------------------- | :---------------------: | :------------------------------------: | :----------------------------------------: |
|
|
|
64 |
| Data | ColabFoldDB, UniRef | HHblits_MSA, Retriever_MSA | AFDB MSA & Structure tokens |
|
65 |
| Global Batch Size | 512 | 256 | 256 |
|
66 |
| Sequence length | 2048 | 12800 | 12800 |
|
67 |
| Per Device Micro Batch Size | 1 | 1 | 1 |
|
68 |
| Precision | Mixed FP32-FP16 | Mixed FP32-FP16 | Mixed FP32-FP16 |
|
69 |
-
| LR
|
70 |
-
| Num Tokens
|
71 |
|
72 |
### Tokenization
|
73 |
|
74 |
We encode protein sequence with single amino acid resolution with 44 vocabularies, where 24 tokens represent amino acid types and 20 are special tokens. Sequences were also suffixed with a `[SEP]` token as hooks for downstream tasks.
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
## How to Use
|
77 |
|
78 |
### Build any downstream models from this backbone with ModelGenerator
|
@@ -93,7 +104,7 @@ import torch
|
|
93 |
from modelgenerator.tasks import Embed
|
94 |
model = Embed.from_config({"model.backbone": "aido_ragprotein_16b"}).eval()
|
95 |
model.backbone.max_length = 12800
|
96 |
-
data = torch.load("
|
97 |
transformed_batch = model.transform(data)
|
98 |
with torch.no_grad():
|
99 |
embedding = model(transformed_batch)
|
@@ -108,7 +119,7 @@ import torch
|
|
108 |
from modelgenerator.tasks import SequenceClassification
|
109 |
model = SequenceClassification.from_config({"model.backbone": "aido_ragprotein_16b", "model.n_classes": 2}).eval()
|
110 |
model.backbone.max_length = 12800
|
111 |
-
data = torch.load("
|
112 |
transformed_batch = model.transform(data)
|
113 |
with torch.no_grad():
|
114 |
logits = model(transformed_batch)
|
@@ -124,7 +135,7 @@ import torch
|
|
124 |
from modelgenerator.tasks import TokenClassification
|
125 |
model = TokenClassification.from_config({"model.backbone": "aido_ragprotein_16b", "model.n_classes": 3}).eval()
|
126 |
model.backbone.max_length = 12800
|
127 |
-
data = torch.load("
|
128 |
transformed_batch = model.transform(data)
|
129 |
with torch.no_grad():
|
130 |
logits = model(transformed_batch)
|
@@ -139,7 +150,7 @@ print(torch.argmax(logits, dim=-1))
|
|
139 |
from modelgenerator.tasks import SequenceRegression
|
140 |
model = SequenceRegression.from_config({"model.backbone": "aido_ragprotein_16b"}).eval()
|
141 |
model.backbone.max_length = 12800
|
142 |
-
data = torch.load("
|
143 |
transformed_batch = model.transform(data)
|
144 |
with torch.no_grad():
|
145 |
logits = model(transformed_batch)
|
|
|
31 |
|
32 |
### Data
|
33 |
|
34 |
+
**UniRef50/Uniclust30 MSA dataset**: We utilized sequences from UniRef50 as queries to search for homologous sequences in UniClust30, subsequently constructing multiple sequence alignments (MSAs). UniRef50 comprises a total of 53.6 million sequences. Using HHblits, we searched all sequences, identifying over 25 homologous sequences for 23.7 million of them. This dataset was directly used as the training set, referred to as `HHblits_MSA`. The remaining 29.9 million sequences were input into MSA Retriever, resulting in 7.7 million sequences with more than 25 homologous sequences. This dataset was designated as `Retriever_MSA`. During training, RAGPLM randomly sampled from the two datasets with probabilities of 0.75 and 0.25. Refer to AIDO.Protein-RAG-3B paper ([link](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1)) for more information.
|
35 |
|
36 |
**AlphaFold Database MSA & Structure dataset**: We downloaded all the structural data from the AlphaFold Database and only kept the structures where the amino acid ratio of pLDDT>70 was greater than 40%. Then we used `mmseqs` to cluster the remaining sequences with `seq id=0.5`, and retained a representative sequence for each class. Final we get 46.9 million sequence/structure pairs. For each structure, we used [genbio-ai/AIDO.StructureTokenizer](https://huggingface.co/genbio-ai/AIDO.StructureTokenizer) to obtain the corresponding structure tokens and structure embedding. And used [MSA Retriever](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1) to obtain the MSA corresponding to the sequence.
|
37 |
|
|
|
45 |
|
46 |
#### (2) UniRef50/Uniclust30 MSA finetuning
|
47 |
|
48 |
+
We used UniRef50/Uniclust30 MSA dataset to finetune the model from stage (1). Refer to AIDO.Protein-RAG-3B paper ([link](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1)) for more information.
|
49 |
|
50 |
#### (3) AFDB MSA & Structure tokens finetuning:
|
51 |
|
52 |
We fine-tuned a pretrained masked language model using MSA data by concatenating the query sequence with homologous sequences. The input structure embedding (hidden dimension 384) is linearly mapped to 2304 and then added to the corresponding embedding of the query sequence tokens.
|
53 |
|
54 |
+
**Ssequence masking**: We introduced several modifications to the standard BERT masking strategy: (1) We randomly sampled `0.05×L` span positions from a query sequence of length `L`, with span lengths following a geometric distribution (`p=0.2`), and capped the maximum length at 10. Our experiments revealed that this settings lead to an average of 15% of the query tokens were masked. (2) To prevent information leakage, when a residue was selected, all residues at the same index across all sequences (the column of the MSA matrix) were also masked. (3) When a column of MSA was selected for masking, the entire column was replaced with the `<MASK>` token in 80% of cases, with random amino acids in 10% of cases, and remained unchanged in the remaining 10% of cases.
|
55 |
|
56 |
+
**Structure masking**: In 20% of the cases, we randomly replaced the structure embedding with 0; in 80% of the cases, we randomly sampled the number of amino acids using the BetaLinear30 distribution and replaced the structure embedding with 0. The BetaLinear30 distribution is defined as a combination of 20% of the Uniform(0, 1) distribution and 80% of the Beta(3, 9) distribution.
|
57 |
|
58 |
+
**Positional embedding**: To help the model distinguish which tokens are from the same chain and which tokens have the same residue index, we use [2D rotary position embedding](https://arxiv.org/abs/2406.05347) to encode the tokens. Refer to AIDO.Protein-RAG-3B paper ([link](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1)) for more information.
|
59 |
|
60 |
+
**Loss**: The loss function consists of a sequence loss function and a structure loss function (weights are 1.0 and 0.01 respectively). The sequence loss function is the CrossEntropy of recovering the masked sequence tokens, and the structure loss function is the CrossEntropy of predicting the masked structure tokens.
|
61 |
|
62 |
| Hyper-params | (1) 1D -> 2D finetuning | (2) UniRef50/Uniclust30 MSA finetuning | (3) AFDB MSA & Structure tokens finetuning |
|
63 |
| --------------------------- | :---------------------: | :------------------------------------: | :----------------------------------------: |
|
64 |
+
| Initialized parameters | AIDO.Protein-16B | Stage (1) | Stage (2) |
|
65 |
| Data | ColabFoldDB, UniRef | HHblits_MSA, Retriever_MSA | AFDB MSA & Structure tokens |
|
66 |
| Global Batch Size | 512 | 256 | 256 |
|
67 |
| Sequence length | 2048 | 12800 | 12800 |
|
68 |
| Per Device Micro Batch Size | 1 | 1 | 1 |
|
69 |
| Precision | Mixed FP32-FP16 | Mixed FP32-FP16 | Mixed FP32-FP16 |
|
70 |
+
| LR | [5e-6,5e-5] | [1e-6, 1e-5] | 1e-5 |
|
71 |
+
| Num Tokens | 10 billion | 100 billion | 80 billion |
|
72 |
|
73 |
### Tokenization
|
74 |
|
75 |
We encode protein sequence with single amino acid resolution with 44 vocabularies, where 24 tokens represent amino acid types and 20 are special tokens. Sequences were also suffixed with a `[SEP]` token as hooks for downstream tasks.
|
76 |
|
77 |
+
## Results
|
78 |
+
|
79 |
+
### Supervised downstream tasks
|
80 |
+
|
81 |
+
<center><img src="supervised_tasks.png" alt="supervised_tasks" style="width:90%; height:auto;" /></center>
|
82 |
+
|
83 |
+
### Supervised DMS fitness score prediction of 25 samples
|
84 |
+
|
85 |
+
<center><img src="supervised_dms.png" alt="supervised_dms" style="width:90%; height:auto;" /></center>
|
86 |
+
|
87 |
## How to Use
|
88 |
|
89 |
### Build any downstream models from this backbone with ModelGenerator
|
|
|
104 |
from modelgenerator.tasks import Embed
|
105 |
model = Embed.from_config({"model.backbone": "aido_ragprotein_16b"}).eval()
|
106 |
model.backbone.max_length = 12800
|
107 |
+
data = torch.load("examples.pt", 'cpu')[0]
|
108 |
transformed_batch = model.transform(data)
|
109 |
with torch.no_grad():
|
110 |
embedding = model(transformed_batch)
|
|
|
119 |
from modelgenerator.tasks import SequenceClassification
|
120 |
model = SequenceClassification.from_config({"model.backbone": "aido_ragprotein_16b", "model.n_classes": 2}).eval()
|
121 |
model.backbone.max_length = 12800
|
122 |
+
data = torch.load("examples.pt", 'cpu')[0]
|
123 |
transformed_batch = model.transform(data)
|
124 |
with torch.no_grad():
|
125 |
logits = model(transformed_batch)
|
|
|
135 |
from modelgenerator.tasks import TokenClassification
|
136 |
model = TokenClassification.from_config({"model.backbone": "aido_ragprotein_16b", "model.n_classes": 3}).eval()
|
137 |
model.backbone.max_length = 12800
|
138 |
+
data = torch.load("examples.pt", 'cpu')[0]
|
139 |
transformed_batch = model.transform(data)
|
140 |
with torch.no_grad():
|
141 |
logits = model(transformed_batch)
|
|
|
150 |
from modelgenerator.tasks import SequenceRegression
|
151 |
model = SequenceRegression.from_config({"model.backbone": "aido_ragprotein_16b"}).eval()
|
152 |
model.backbone.max_length = 12800
|
153 |
+
data = torch.load("examples.pt", 'cpu')[0]
|
154 |
transformed_batch = model.transform(data)
|
155 |
with torch.no_grad():
|
156 |
logits = model(transformed_batch)
|
examples.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e60409786b001317d5150ab440521fa1f9a6a90ca18f4666e27740b4e6a75aa5
|
3 |
+
size 5485326
|