---
license: other
---
# AIDO.Protein-RAG-3B

AIDO.Protein-RAG-3B (AIDO.RAGPLM) is a pretrained Retrieval-Augmented protein language model within an [AI-driven Digital Organism](https://arxiv.org/abs/2412.06993) framework. This model, along with [AIDO.RAGFold](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1), integrates pretrained protein language models with retrieved Multiple Sequence Alignments (MSA), enabling the incorporation of co-evolutionary information for structure prediction while compensating for limited MSA data through large-scale pretraining.

AIDO.Protein-RAG-3B outperforms single-sequence protein language models in perplexity, contact prediction, and fitness prediction. When used as a feature extractor for structure prediction in [AIDO.RAGFold](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1), it achieves TM-scores comparable to AlphaFold2 when sufficient MSA data is available (at an 8x faster runtime), and significantly surpasses AlphaFold2 in MSA-limited scenarios (∆TM-score = 0.379, 0.116, and 0.059 for 0, 5, and 10 input sequences, respectively).

## Model Architecture

AIDO.Protein-RAG-3B employs a transformer encoder-only architecture with dense MLP layers in each block (Panel **c** below). The model uses single amino acid tokenization and is optimized via masked language modeling (MLM).

<center><img src="architecture.png" alt="An Overview of AIDO.Protein" style="width:90%; height:auto;" /></center>

More architecture details are shown below:

| Hyperparameter | Value |
| --------------- | ----- |
| FFN Hidden Size | 6832 |
| Context Length | 12.8K |

## Pre-training

### Data Preparation

**UniRef50/UniClust30 MSA dataset**: We utilized sequences from UniRef50 as queries to search for homologous sequences in UniClust30, subsequently constructing multiple sequence alignments (MSAs). UniRef50 comprises a total of 53.6 million sequences. Using HHblits, we searched all sequences, identifying over 25 homologous sequences for 23.7 million of them. This dataset was directly used as the training set, referred to as `HHblits_MSA`. The remaining 29.9 million sequences were input into MSA Retriever, resulting in 7.7 million sequences with more than 25 homologous sequences. This dataset was designated as `Retriever_MSA`. During training, RAGPLM randomly sampled from the two datasets with probabilities of 0.75 and 0.25, respectively.
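
As an illustration of this sampling scheme, the sketch below draws each training example from `HHblits_MSA` with probability 0.75 and from `Retriever_MSA` with probability 0.25. The dataset objects are hypothetical stand-ins; the real datasets hold millions of MSAs.

```python
import random

# Hypothetical stand-ins for the two MSA training sets described above.
hhblits_msa = ['msa_a', 'msa_b']    # ~23.7M entries in practice
retriever_msa = ['msa_c', 'msa_d']  # ~7.7M entries in practice

def sample_training_msa():
    """Pick HHblits_MSA with p=0.75, otherwise Retriever_MSA, then draw one MSA."""
    source = hhblits_msa if random.random() < 0.75 else retriever_msa
    return random.choice(source)

batch = [sample_training_msa() for _ in range(8)]
```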

### Unsupervised contact prediction

<center><img src="unsupervised_contact_prediction.png" alt="unsupervised_contact_prediction" style="width:90%; height:auto;" /></center>

### Supervised downstream tasks

<center><img src="supervised_tasks.png" alt="supervised_tasks" style="width:90%; height:auto;" /></center>

## How to Use

### Build Downstream Models Using ModelGenerator

For more information, visit: [Model Generator](https://github.com/genbio-ai/modelgenerator)

```bash
mgen fit --model SequenceClassification --model.backbone aido_protein_rag_3b --data SequenceClassificationDataModule --data.path <hf_or_local_path_to_your_dataset>
mgen test --model SequenceClassification --model.backbone aido_protein_rag_3b --data SequenceClassificationDataModule --data.path <hf_or_local_path_to_your_dataset>
```
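
Here, `mgen fit` trains a SequenceClassification model on the `aido_protein_rag_3b` backbone with your dataset, and `mgen test` evaluates the trained model; `<hf_or_local_path_to_your_dataset>` may be either a Hugging Face dataset identifier or a local path.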

### Use Directly in Python

#### Embedding

```python
import random

import numpy as np
import torch

from modelgenerator.tasks import Embed

model = Embed.from_config({"model.backbone": "aido_protein_rag_3b"}).eval()
model.backbone.max_length = 12800

# Toy input: one random 50-residue query, a 25-sequence MSA ('-' marks gaps),
# and per-residue structure embeddings of dimension 384.
restypes = 'ARNDCQEGHILKMFPSTWYV'
data = {
    'sequences': [''.join(random.choice(restypes) for _ in range(50))],
    'msa': [[''.join(random.choice(restypes + '-') for _ in range(50)) for _ in range(25)]],
    'str_emb': np.random.normal(size=(1, 50, 384)),
}

transformed_batch = model.transform(data)
with torch.no_grad():
    embedding = model(transformed_batch)
```
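
In this input format, `sequences` holds the query protein(s), `msa` holds one list of aligned homologous sequences per query (gaps as `-`), and `str_emb` appears to supply per-residue structure embeddings of shape `(batch, length, 384)`. The random values above are placeholders, so substitute real MSAs (e.g., from MSA Retriever) to obtain meaningful embeddings.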

#### Sequence Level Classification

```python
import random

import numpy as np
import torch

from modelgenerator.tasks import SequenceClassification

model = SequenceClassification.from_config({"model.backbone": "aido_protein_rag_3b", "model.n_classes": 2}).eval()
model.backbone.max_length = 12800

# Toy input in the same format as the embedding example above.
restypes = 'ARNDCQEGHILKMFPSTWYV'
data = {
    'sequences': [''.join(random.choice(restypes) for _ in range(50))],
    'msa': [[''.join(random.choice(restypes + '-') for _ in range(50)) for _ in range(25)]],
    'str_emb': np.random.normal(size=(1, 50, 384)),
}

transformed_batch = model.transform(data)
with torch.no_grad():
    logits = model(transformed_batch)
print(logits)
print(torch.argmax(logits, dim=-1))
```

#### Token Level Classification

```python
import random

import numpy as np
import torch

from modelgenerator.tasks import TokenClassification

model = TokenClassification.from_config({"model.backbone": "aido_protein_rag_3b", "model.n_classes": 3}).eval()
model.backbone.max_length = 12800

# Toy input in the same format as the embedding example above.
restypes = 'ARNDCQEGHILKMFPSTWYV'
data = {
    'sequences': [''.join(random.choice(restypes) for _ in range(50))],
    'msa': [[''.join(random.choice(restypes + '-') for _ in range(50)) for _ in range(25)]],
    'str_emb': np.random.normal(size=(1, 50, 384)),
}

transformed_batch = model.transform(data)
with torch.no_grad():
    logits = model(transformed_batch)
print(logits)
print(torch.argmax(logits, dim=-1))
```

#### Sequence Level Regression

```python
import random

import numpy as np
import torch

from modelgenerator.tasks import SequenceRegression

model = SequenceRegression.from_config({"model.backbone": "aido_protein_rag_3b"}).eval()
model.backbone.max_length = 12800

# Toy input in the same format as the embedding example above.
restypes = 'ARNDCQEGHILKMFPSTWYV'
data = {
    'sequences': [''.join(random.choice(restypes) for _ in range(50))],
    'msa': [[''.join(random.choice(restypes + '-') for _ in range(50)) for _ in range(25)]],
    'str_emb': np.random.normal(size=(1, 50, 384)),
}

transformed_batch = model.transform(data)
with torch.no_grad():
    logits = model(transformed_batch)
print(logits)
```