Commit c4a5068
Parent(s): ef7baa6
Message: update readme3

Files changed:
- README.md (+64 -22)
- examples.pt (+0 -3)

README.md (CHANGED)
# AIDO.RAGProtein-16B

AIDO.RAGProtein-16B is a multimodal protein language model that integrates Multiple Sequence Alignment (MSA) and structural data, building upon the [AIDO.Protein-16B](https://huggingface.co/genbio-ai/AIDO.Protein-16B) foundation. The training process comprises three main stages:

1. 2D RoPE encoding fine-tuning
2. Initial training on 100 billion tokens from UniRef50/UniClust30 MSA data
3. Subsequent training on 80 billion tokens from AlphaFold Database MSA and structural data
## Model Architecture Details

AIDO.RAGProtein-16B employs a transformer encoder-only architecture featuring sparse Mixture-of-Experts (MoE) layers that replace the dense MLP layers in each transformer block. The model uses single amino acid tokenization, is optimized with masked language modeling (MLM), and activates 2 experts per token via top-2 routing.

<center><img src="proteinmoe_architecture.png" alt="An Overview of AIDO.Protein" style="width:70%; height:auto;" /></center>
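To make the top-2 routing concrete, here is a minimal, self-contained sketch of a sparse MoE feed-forward layer with top-2 expert selection. It is illustrative only: the layer sizes, expert count, and expert MLP shape are assumptions, not the actual AIDO.RAGProtein-16B implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Top2MoELayer(nn.Module):
    """Sparse MoE feed-forward layer: each token is routed to its top-2 experts."""

    def __init__(self, d_model=1024, d_ff=4096, n_experts=8):
        super().__init__()
        self.router = nn.Linear(d_model, n_experts)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
            for _ in range(n_experts)
        )

    def forward(self, x):                        # x: (n_tokens, d_model)
        gate = self.router(x)                    # (n_tokens, n_experts)
        weights, idx = gate.topk(2, dim=-1)      # top-2 experts per token
        weights = F.softmax(weights, dim=-1)     # renormalize over the 2 selected experts
        out = torch.zeros_like(x)
        for slot in range(2):
            for e, expert in enumerate(self.experts):
                sel = idx[:, slot] == e          # tokens whose slot-th choice is expert e
                if sel.any():
                    out[sel] += weights[sel, slot].unsqueeze(-1) * expert(x[sel])
        return out

moe = Top2MoELayer()
tokens = torch.randn(16, 1024)
print(moe(tokens).shape)                         # torch.Size([16, 1024])
```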
## Pre-training of AIDO.RAGProtein-16B

Here we briefly introduce the pre-training of AIDO.RAGProtein-16B, which is divided into three stages: (1) 1D -> 2D RoPE encoding fine-tuning; (2) UniRef50/Uniclust30 MSA fine-tuning; (3) AlphaFold Database MSA & structure token fine-tuning.

### Data
**UniRef50/Uniclust30 MSA dataset**: We utilized sequences from UniRef50 as queries to search for homologous sequences in UniClust30, subsequently constructing multiple sequence alignments (MSAs). UniRef50 comprises a total of 53.6 million sequences. Using HHblits, we searched all sequences, identifying over 25 homologous sequences for 23.7 million of them. This dataset was directly used as the training set, referred to as `HHblits_MSA`. The remaining 29.9 million sequences were input into MSA Retriever, resulting in 7.7 million sequences with more than 25 homologous sequences. This dataset was designated as `Retriever_MSA`. During training, RAGPLM randomly sampled from the two datasets with probabilities of 0.75 and 0.25. Refer to the AIDO.Protein-RAG-3B paper ([link](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1)) for more information.

**AlphaFold Database MSA & Structure dataset**: We downloaded all structural data from the AlphaFold Database and kept only those where more than 40% of amino acids had a pLDDT score > 70. The remaining sequences were clustered using `mmseqs` (`seq id=0.5`), and one representative per cluster was retained, resulting in 46.9 million sequence/structure pairs. For each structure, we used [genbio-ai/AIDO.StructureTokenizer](https://huggingface.co/genbio-ai/AIDO.StructureTokenizer) to obtain structure tokens and embeddings. [MSA Retriever](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1) was used to obtain the corresponding MSA.
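The two selection rules above (the pLDDT-based structure filter and the 0.75/0.25 dataset mix) can be summarized in a short sketch. The function names and the per-residue `plddt` array are illustrative assumptions, not code from the actual data pipeline.

```python
import random

import numpy as np

def keep_afdb_structure(plddt: np.ndarray, min_frac=0.4, min_plddt=70.0) -> bool:
    """Keep an AlphaFold structure only if more than 40% of residues have pLDDT > 70."""
    return float((plddt > min_plddt).mean()) > min_frac

def sample_msa_source() -> str:
    """Pick the MSA training source with the 0.75 / 0.25 mix described above."""
    return 'HHblits_MSA' if random.random() < 0.75 else 'Retriever_MSA'

print(keep_afdb_structure(np.array([82.0, 64.0, 91.0, 55.0, 73.0])))  # True: 3/5 residues pass
print(sample_msa_source())
```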
### Training Details

Model training is divided into three stages:

#### (1) 1D -> 2D RoPE Encoding Fine-tuning

Same training data as [AIDO.Protein-16B](https://huggingface.co/genbio-ai/AIDO.Protein-16B), but with [2D rotary position embedding](https://arxiv.org/abs/2406.05347) for token encoding.

#### (2) UniRef50/UniClust30 MSA Fine-tuning

The model from Stage 1 is further fine-tuned on the UniRef50/Uniclust30 MSA dataset. See the [AIDO.Protein-RAG-3B paper](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1) for more details.

#### (3) AlphaFold Database MSA & Structure Fine-tuning

We fine-tuned the model on concatenated query and homologous sequences. Structure embeddings (dim = 384) are linearly mapped to 2304 dimensions and added to the query token embeddings, as sketched below.
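As a small illustration of the sentence above, the snippet projects 384-dimensional structure embeddings to the model width and adds them to the query token embeddings. The linear mapping and the 2304-dimensional target follow the text; the batch size and sequence length are arbitrary.

```python
import torch
import torch.nn as nn

struct_proj = nn.Linear(384, 2304)       # maps structure embeddings to the model width
token_emb = torch.randn(1, 50, 2304)     # query token embeddings (batch, length, dim)
str_emb = torch.randn(1, 50, 384)        # per-residue structure embeddings
hidden = token_emb + struct_proj(str_emb)
print(hidden.shape)                      # torch.Size([1, 50, 2304])
```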
##### Sequence Masking

* Randomly sample `0.05 × L` span start positions from a query of length `L`. Span lengths follow a geometric distribution (`p = 0.2`), capped at 10 residues; on average, ~15% of query tokens are masked.
* When a residue is selected, its aligned residues across all sequences (the MSA column) are also masked.
* For masked MSA columns, 80% are replaced with `<MASK>`, 10% with random amino acids, and 10% are left unchanged (see the sketch after this list).
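A minimal sketch of this span-masking scheme, assuming exactly the parameters stated above (0.05 × L span starts, geometric span lengths with p = 0.2 capped at 10, and the 80/10/10 replacement rule). It is illustrative, not the actual training code.

```python
import random

RESTYPES = 'ARNDCQEGHILKMFPSTWYV'

def sample_geometric(p=0.2, cap=10):
    """Span length ~ Geometric(p), capped at `cap`."""
    length = 1
    while random.random() > p and length < cap:
        length += 1
    return length

def sample_masked_positions(L, span_rate=0.05):
    """Sample masked query positions as spans; masking a position masks its whole MSA column."""
    masked = set()
    for _ in range(max(1, int(span_rate * L))):
        start = random.randrange(L)
        masked.update(range(start, min(start + sample_geometric(), L)))
    return masked

def corrupt_column(column):
    """80% -> <MASK>, 10% -> random amino acids, 10% -> unchanged (applied per masked column)."""
    r = random.random()
    if r < 0.8:
        return ['<MASK>'] * len(column)
    if r < 0.9:
        return [random.choice(RESTYPES) for _ in column]
    return list(column)
```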
##### Structure Masking

* In 20% of cases, all structure embeddings are replaced with 0.
* In the remaining 80% of cases, a number of amino acids is sampled using the BetaLinear30 distribution (a mixture of 20% Uniform(0, 1) and 80% Beta(3, 9)) and the corresponding structure embeddings are zeroed (a sketch follows this list).
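A sketch of this structure-embedding masking, assuming `str_emb` is a NumPy array of shape `(L, 384)`. Interpreting the BetaLinear30 sample as the fraction of residues to zero follows the description above; the rest is an assumption rather than the actual training code.

```python
import random

import numpy as np

def betalinear30():
    """BetaLinear30: 20% Uniform(0, 1) + 80% Beta(3, 9)."""
    return random.random() if random.random() < 0.2 else random.betavariate(3, 9)

def mask_structure(str_emb: np.ndarray) -> np.ndarray:
    str_emb = str_emb.copy()
    if random.random() < 0.2:
        str_emb[:] = 0.0                      # 20% of cases: drop all structure embeddings
    else:
        L = str_emb.shape[0]
        n = int(round(betalinear30() * L))    # 80% of cases: zero a sampled fraction of residues
        for i in random.sample(range(L), n):
            str_emb[i] = 0.0
    return str_emb

masked = mask_structure(np.random.normal(size=(50, 384)))
```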
##### Positional Embedding

We use [2D rotary position embedding](https://arxiv.org/abs/2406.05347) to help the model distinguish token chain identities and residue indices. See the AIDO.Protein-RAG-3B paper ([link](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1)) for more information.
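One common way to realize a 2D rotary embedding is to rotate half of each feature vector by one coordinate (here, the chain identity) and the other half by the second coordinate (the residue index). The sketch below shows that generic construction under this assumption; it is not necessarily the exact formulation of the cited paper or of this model.

```python
import torch

def rope_1d(x, pos, base=10000.0):
    """Standard 1D RoPE on the last dim of x (must be even); pos: (L,)."""
    d = x.shape[-1]
    inv_freq = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)
    ang = pos.float()[:, None] * inv_freq[None, :]           # (L, d/2)
    cos, sin = ang.cos(), ang.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def rope_2d(x, chain_ids, residue_ids):
    """Rotate the first half of the features by chain id, the second half by residue index."""
    half = x.shape[-1] // 2
    return torch.cat([rope_1d(x[..., :half], chain_ids),
                      rope_1d(x[..., half:], residue_ids)], dim=-1)

q = torch.randn(75, 64)                # 3 chains of 25 residues each, head dim 64 (illustrative)
chain_ids = torch.arange(75) // 25     # which sequence a token belongs to
residue_ids = torch.arange(75) % 25    # position within its sequence
print(rope_2d(q, chain_ids, residue_ids).shape)   # torch.Size([75, 64])
```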
##### Loss Function

The total loss is a weighted sum of the sequence loss (weight 1.0) and the structure loss (weight 0.01):

* **Sequence loss**: cross-entropy loss for masked amino acid token prediction.
* **Structure loss**: cross-entropy loss for masked structure token prediction.
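The weighted sum above can be written directly; the weights come from the text, while the tensor names and toy shapes below are placeholders.

```python
import torch
import torch.nn.functional as F

def total_loss(seq_logits, seq_targets, str_logits, str_targets,
               w_seq=1.0, w_str=0.01):
    """Weighted sum of masked-token cross-entropy losses."""
    seq_loss = F.cross_entropy(seq_logits, seq_targets)   # masked amino-acid prediction
    str_loss = F.cross_entropy(str_logits, str_targets)   # masked structure-token prediction
    return w_seq * seq_loss + w_str * str_loss

# Toy shapes: 10 masked positions, 32 amino-acid classes, 512 structure-token classes (illustrative)
loss = total_loss(torch.randn(10, 32), torch.randint(0, 32, (10,)),
                  torch.randn(10, 512), torch.randint(0, 512, (10,)))
```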
| Hyper-params                | (1) 1D -> 2D finetuning | (2) UniRef50/Uniclust30 MSA finetuning | (3) AFDB MSA & Structure tokens finetuning |
| --------------------------- | :---------------------: | :------------------------------------: | :----------------------------------------: |
## Results

### Supervised Downstream Tasks

<center><img src="supervised_tasks.png" alt="supervised_tasks" style="width:90%; height:auto;" /></center>

### Supervised DMS Fitness Score Prediction (25 Samples)

<center><img src="supervised_dms.png" alt="supervised_dms" style="width:90%; height:auto;" /></center>
## How to Use

### Build Downstream Models Using ModelGenerator

For more information, visit: [Model Generator](https://github.com/genbio-ai/modelgenerator)

```bash
mgen test --model SequenceClassification --model.backbone aido_protein_rag_16b --data SequenceClassificationDataModule --data.path <hf_or_local_path_to_your_dataset>
```
### Use Directly in Python

#### Embedding

```python
import random

import numpy as np
import torch
from modelgenerator.tasks import Embed

model = Embed.from_config({"model.backbone": "aido_protein_rag_16b"}).eval()
model.backbone.max_length = 12800

# Toy input: one random 50-residue query, a 25-sequence MSA, and random structure embeddings
restypes = 'ARNDCQEGHILKMFPSTWYV'
data = {
    'sequences': [''.join(random.choice(restypes) for _ in range(50))],
    'msa': [[''.join(random.choice(restypes + '-') for _ in range(50)) for _ in range(25)]],
    'str_emb': np.random.normal(size=(1, 50, 384)),
}

transformed_batch = model.transform(data)
with torch.no_grad():
    embedding = model(transformed_batch)
```
Sequence-level classification (a 2-class example):

```python
import random

import numpy as np
import torch
from modelgenerator.tasks import SequenceClassification

model = SequenceClassification.from_config({"model.backbone": "aido_protein_rag_16b", "model.n_classes": 2}).eval()
model.backbone.max_length = 12800

restypes = 'ARNDCQEGHILKMFPSTWYV'
data = {
    'sequences': [''.join(random.choice(restypes) for _ in range(50))],
    'msa': [[''.join(random.choice(restypes + '-') for _ in range(50)) for _ in range(25)]],
    'str_emb': np.random.normal(size=(1, 50, 384)),
}

transformed_batch = model.transform(data)
with torch.no_grad():
    logits = model(transformed_batch)
```
Token-level classification (a 3-class example):

```python
import random

import numpy as np
import torch
from modelgenerator.tasks import TokenClassification

model = TokenClassification.from_config({"model.backbone": "aido_protein_rag_16b", "model.n_classes": 3}).eval()
model.backbone.max_length = 12800

restypes = 'ARNDCQEGHILKMFPSTWYV'
data = {
    'sequences': [''.join(random.choice(restypes) for _ in range(50))],
    'msa': [[''.join(random.choice(restypes + '-') for _ in range(50)) for _ in range(25)]],
    'str_emb': np.random.normal(size=(1, 50, 384)),
}

transformed_batch = model.transform(data)
with torch.no_grad():
    logits = model(transformed_batch)
```
Sequence-level regression:

```python
import random

import numpy as np
import torch
from modelgenerator.tasks import SequenceRegression

model = SequenceRegression.from_config({"model.backbone": "aido_protein_rag_16b"}).eval()
model.backbone.max_length = 12800

restypes = 'ARNDCQEGHILKMFPSTWYV'
data = {
    'sequences': [''.join(random.choice(restypes) for _ in range(50))],
    'msa': [[''.join(random.choice(restypes + '-') for _ in range(50)) for _ in range(25)]],
    'str_emb': np.random.normal(size=(1, 50, 384)),
}

transformed_batch = model.transform(data)
with torch.no_grad():
    logits = model(transformed_batch)
```
examples.pt (DELETED)

version https://git-lfs.github.com/spec/v1
oid sha256:e60409786b001317d5150ab440521fa1f9a6a90ca18f4666e27740b4e6a75aa5
size 5485326