Pan6461188 committed on
Commit 4ae63c9 · 1 Parent(s): a9ee239

update readme

Files changed (2)
  1. README.md +23 -12
  2. examples.pt +3 -0
README.md CHANGED
@@ -31,7 +31,7 @@ Here we briefly introduce the details of pre-training of AIDO.RAGProtein-16B. Ma

### Data

- **UniRef50/Uniclust30 MSA dataset**: We utilized sequences from UniRef50 as queries to search for homologous sequences in UniClust30, subsequently constructing multiple sequence alignments (MSAs). UniRef50 comprises a total of 53.6 million sequences. Using HHblits, we searched all sequences, identifying over 25 homologous sequences for 23.7 million of them. This dataset was directly used as the training set, referred to as `HHblits_MSA`. The remaining 29.9 million sequences were input into MSA Retriever, resulting in 7.7 million sequences with more than 25 homologous sequences. This dataset was designated as `Retriever_MSA`. During training, RAGPLM randomly sampled from the two datasets with probabilities of 0.75 and 0.25

**AlphaFold Database MSA & Structure dataset**: We downloaded all structural data from the AlphaFold Database and kept only the structures in which more than 40% of residues have pLDDT > 70. We then used `mmseqs` to cluster the remaining sequences at `seq id=0.5` and retained one representative sequence per cluster, yielding 46.9 million sequence/structure pairs. For each structure, we used [genbio-ai/AIDO.StructureTokenizer](https://huggingface.co/genbio-ai/AIDO.StructureTokenizer) to obtain the corresponding structure tokens and structure embedding, and [MSA Retriever](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1) to obtain the MSA corresponding to each sequence.
@@ -45,34 +45,45 @@ Same training data with [AIDO.Protein-16B](https://huggingface.co/genbio-ai/AIDO

#### (2) UniRef50/Uniclust30 MSA finetuning

- We used UniRef50/Uniclust30 MSA dataset to finetune the model from stage (1). Refer [AIDO.RAGPLM](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1) for more information.

#### (3) AFDB MSA & Structure tokens finetuning:

We fine-tuned a pretrained masked language model using MSA data by concatenating the query sequence with homologous sequences. The input structure embedding (hidden dimension 384) is linearly mapped to 2304 and then added to the corresponding embedding of the query sequence tokens.
- **Mask of sequences**: We introduced several modifications to the standard BERT masking strategy: (1) We randomly sampled `0.05×L` span positions from a query sequence of length `L`, with span lengths following a geometric distribution (`p=0.2`), and capped the maximum length at 10. Our experiments revealed that this settings lead to an average of 15% of the query tokens were masked. (2) To prevent information leakage, when a residue was selected, all residues at the same index across all sequences (the column of the MSA matrix) were also masked. (3) When a column of MSA was selected for masking, the entire column was replaced with the `<MASK>` token in 80% of cases, with random amino acids in 10% of cases, and remained unchanged in the remaining 10% of cases.

- **Mask of structure**: In 20% of the cases, we randomly replaced the structure embedding with 0; in 80% of the cases, we randomly sampled a certain number of amino acids using the BetaLinear30 distribution and masked their structure embedding. The BetaLinear30 distribution is defined as a combination of 20% of the [0, 1] uniform distribution and 80% of the Beta(3, 9) Beta distribution.

- **Positional embedding**: To help the model distinguish which tokens are from the same chain and which tokens have the same residue index, we use [2D rotary position embedding](https://arxiv.org/abs/2406.05347) to encode the tokens.

- **Loss**: The loss function consists of a sequence loss function and a structure loss function (weights are 1.0 and 0.01 respectively). The sequence loss function is the CrossEntropy function that recovers the masked sequence tokens, and the structure loss function is the CrossEntropy function that predicts each masked structure token.
| Hyper-params | (1) 1D -> 2D finetuning | (2) UniRef50/Uniclust30 MSA finetuning | (3) AFDB MSA & Structure tokens finetuning |
| --------------------------- | :---------------------: | :------------------------------------: | :----------------------------------------: |
| Data | ColabFoldDB, UniRef | HHblits_MSA, Retriever_MSA | AFDB MSA & Structure tokens |
| Global Batch Size | 512 | 256 | 256 |
| Sequence length | 2048 | 12800 | 12800 |
| Per Device Micro Batch Size | 1 | 1 | 1 |
| Precision | Mixed FP32-FP16 | Mixed FP32-FP16 | Mixed FP32-FP16 |
- | LR | [5e-6, 5e-5] | [1e-6, 1e-5] | 1e-5 |
- | Num Tokens | 10 billion | 100 billion | 80 billion |
 
### Tokenization

We encode protein sequences at single-amino-acid resolution with a 44-token vocabulary, where 24 tokens represent amino acid types and 20 are special tokens. Sequences are also suffixed with a `[SEP]` token as a hook for downstream tasks.

## How to Use

### Build any downstream models from this backbone with ModelGenerator
@@ -93,7 +104,7 @@ import torch
from modelgenerator.tasks import Embed
model = Embed.from_config({"model.backbone": "aido_ragprotein_16b"}).eval()
model.backbone.max_length = 12800
- data = torch.load("ModelGenerator/experiments/AIDO.RAGPLM/examples.pt", 'cpu')[0]
transformed_batch = model.transform(data)
with torch.no_grad():
    embedding = model(transformed_batch)
@@ -108,7 +119,7 @@ import torch
from modelgenerator.tasks import SequenceClassification
model = SequenceClassification.from_config({"model.backbone": "aido_ragprotein_16b", "model.n_classes": 2}).eval()
model.backbone.max_length = 12800
- data = torch.load("ModelGenerator/experiments/AIDO.RAGPLM/examples.pt", 'cpu')[0]
transformed_batch = model.transform(data)
with torch.no_grad():
    logits = model(transformed_batch)
@@ -124,7 +135,7 @@ import torch
from modelgenerator.tasks import TokenClassification
model = TokenClassification.from_config({"model.backbone": "aido_ragprotein_16b", "model.n_classes": 3}).eval()
model.backbone.max_length = 12800
- data = torch.load("ModelGenerator/experiments/AIDO.RAGPLM/examples.pt", 'cpu')[0]
transformed_batch = model.transform(data)
with torch.no_grad():
    logits = model(transformed_batch)
@@ -139,7 +150,7 @@ print(torch.argmax(logits, dim=-1))
from modelgenerator.tasks import SequenceRegression
model = SequenceRegression.from_config({"model.backbone": "aido_ragprotein_16b"}).eval()
model.backbone.max_length = 12800
- data = torch.load("ModelGenerator/experiments/AIDO.RAGPLM/examples.pt", 'cpu')[0]
transformed_batch = model.transform(data)
with torch.no_grad():
    logits = model(transformed_batch)
 
### Data

+ **UniRef50/Uniclust30 MSA dataset**: We utilized sequences from UniRef50 as queries to search for homologous sequences in UniClust30, subsequently constructing multiple sequence alignments (MSAs). UniRef50 comprises a total of 53.6 million sequences. Using HHblits, we searched all sequences, identifying over 25 homologous sequences for 23.7 million of them. This dataset was directly used as the training set, referred to as `HHblits_MSA`. The remaining 29.9 million sequences were input into MSA Retriever, resulting in 7.7 million sequences with more than 25 homologous sequences. This dataset was designated as `Retriever_MSA`. During training, RAGPLM randomly sampled from the two datasets with probabilities of 0.75 and 0.25. Refer to the AIDO.Protein-RAG-3B paper ([link](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1)) for more information.
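
As a concrete illustration of the 0.75/0.25 sampling described above, the sketch below mixes the two MSA datasets at the stated rates. The function and record names are illustrative placeholders, not the actual training pipeline.

```python
# Illustrative sketch: draw each training example from HHblits_MSA with
# probability 0.75 and from Retriever_MSA with probability 0.25.
import random

def sample_msa_example(hhblits_msa, retriever_msa, p_hhblits=0.75):
    """Pick one MSA record from the two datasets with the stated probabilities."""
    source = hhblits_msa if random.random() < p_hhblits else retriever_msa
    return random.choice(source)

# Toy usage with placeholder records
hhblits_msa = [{"query": "MKTAYIAK", "msa": ["MKTAYIAK", "MRTAYLAK"]}]
retriever_msa = [{"query": "GSAMKLVV", "msa": ["GSAMKLVV", "GSVMKLIV"]}]
example = sample_msa_example(hhblits_msa, retriever_msa)
```
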
**AlphaFold Database MSA & Structure dataset**: We downloaded all structural data from the AlphaFold Database and kept only the structures in which more than 40% of residues have pLDDT > 70. We then used `mmseqs` to cluster the remaining sequences at `seq id=0.5` and retained one representative sequence per cluster, yielding 46.9 million sequence/structure pairs. For each structure, we used [genbio-ai/AIDO.StructureTokenizer](https://huggingface.co/genbio-ai/AIDO.StructureTokenizer) to obtain the corresponding structure tokens and structure embedding, and [MSA Retriever](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1) to obtain the MSA corresponding to each sequence.
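
The pLDDT filter above can be read as a simple per-structure check; a minimal sketch (with an illustrative function name, assuming a per-residue pLDDT list) is shown below.

```python
# Minimal sketch of the described filter: keep a structure only if more than
# 40% of its residues have pLDDT > 70. `plddt` is a per-residue list of scores.
def keep_structure(plddt, threshold=70.0, min_fraction=0.4):
    high_confidence = sum(1 for p in plddt if p > threshold)
    return high_confidence / len(plddt) > min_fraction

print(keep_structure([90.0, 85.0, 60.0, 72.0]))  # True: 3/4 residues above 70
print(keep_structure([50.0, 65.0, 71.0]))        # False: only 1/3 above 70
```
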
 
#### (2) UniRef50/Uniclust30 MSA finetuning

+ We used the UniRef50/Uniclust30 MSA dataset to finetune the model from stage (1). Refer to the AIDO.Protein-RAG-3B paper ([link](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1)) for more information.

#### (3) AFDB MSA & Structure tokens finetuning:

We fine-tuned a pretrained masked language model using MSA data by concatenating the query sequence with homologous sequences. The input structure embedding (hidden dimension 384) is linearly mapped to 2304 and then added to the corresponding embedding of the query sequence tokens.
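
The projection-and-add step can be sketched as follows. Shapes, module names, and the toy batch are assumptions for illustration only; they are not the actual AIDO.RAGProtein-16B implementation.

```python
# Sketch: project the 384-dim structure embedding to the 2304-dim model width
# and add it to the embeddings of the query-sequence tokens (assumed layout:
# the query tokens come first, followed by the homologous sequences).
import torch
import torch.nn as nn

struct_proj = nn.Linear(384, 2304)          # 384 -> model hidden size

query_len, msa_len = 128, 512
token_emb = torch.randn(1, query_len + msa_len, 2304)   # token embeddings of query + MSA
struct_emb = torch.randn(1, query_len, 384)             # per-residue structure embedding

token_emb[:, :query_len] = token_emb[:, :query_len] + struct_proj(struct_emb)
```
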
+ **Sequence masking**: We introduced several modifications to the standard BERT masking strategy: (1) We randomly sampled `0.05×L` span positions from a query sequence of length `L`, with span lengths following a geometric distribution (`p=0.2`), capped at a maximum length of 10. Our experiments revealed that these settings lead to an average of 15% of the query tokens being masked. (2) To prevent information leakage, when a residue was selected, all residues at the same index across all sequences (the column of the MSA matrix) were also masked. (3) When a column of the MSA was selected for masking, the entire column was replaced with the `<MASK>` token in 80% of cases, with random amino acids in 10% of cases, and remained unchanged in the remaining 10% of cases.
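
A minimal sketch of this span/column masking is shown below. Function names and the toy MSA are illustrative assumptions; the geometric span length is sampled by repeated Bernoulli trials with `p=0.2` and capped at 10.

```python
# Illustrative sketch of the masking described above (not the exact training code).
# `msa` is a list of equal-length sequences; msa[0] is the query.
import random

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"

def sample_masked_columns(L, span_rate=0.05, p_geom=0.2, max_span=10):
    """Sample 0.05*L span starts; span lengths ~ Geometric(0.2), capped at 10."""
    cols = set()
    for _ in range(int(span_rate * L)):
        start = random.randrange(L)
        length = 1
        while random.random() > p_geom and length < max_span:
            length += 1
        cols.update(range(start, min(start + length, L)))
    return cols

def apply_column_masking(msa, cols):
    """For each selected column: 80% <MASK>, 10% random amino acids, 10% unchanged."""
    tokens = [list(seq) for seq in msa]          # per-sequence token lists
    for c in cols:
        r = random.random()
        if r < 0.8:
            for row in tokens:
                row[c] = "<MASK>"
        elif r < 0.9:
            for row in tokens:
                row[c] = random.choice(AMINO_ACIDS)
        # else: leave the whole column unchanged (remaining 10%)
    return tokens

msa = ["MKTAYIAKQR", "MRTAYLAKQR", "MKSAYIARQR"]
masked = apply_column_masking(msa, sample_masked_columns(len(msa[0])))
```
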
+ **Structure masking**: In 20% of cases, we replaced the entire structure embedding with 0; in the remaining 80% of cases, we sampled a masking ratio from the BetaLinear30 distribution and replaced the structure embeddings of that fraction of randomly selected residues with 0. The BetaLinear30 distribution is defined as a mixture of 20% Uniform(0, 1) and 80% Beta(3, 9).
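
A sketch of this structure masking is given below, under the assumption that the BetaLinear30 draw is used as the per-sample fraction of residues whose structure embedding is zeroed; names and shapes are illustrative.

```python
# Illustrative sketch of the structure-embedding masking described above.
import random
import torch

def sample_betalinear30():
    """Mixture of 20% Uniform(0, 1) and 80% Beta(3, 9)."""
    if random.random() < 0.2:
        return random.random()
    return random.betavariate(3, 9)

def mask_structure_embedding(struct_emb):
    """struct_emb: (L, d) per-residue structure embeddings."""
    L = struct_emb.shape[0]
    if random.random() < 0.2:
        return torch.zeros_like(struct_emb)       # drop all structure information
    out = struct_emb.clone()
    n_masked = int(sample_betalinear30() * L)     # fraction of residues to mask
    idx = torch.randperm(L)[:n_masked]
    out[idx] = 0.0
    return out

masked_emb = mask_structure_embedding(torch.randn(128, 384))
```
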
+ **Positional embedding**: To help the model distinguish which tokens are from the same chain and which tokens have the same residue index, we use [2D rotary position embedding](https://arxiv.org/abs/2406.05347) to encode the tokens. Refer to the AIDO.Protein-RAG-3B paper ([link](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1)) for more information.
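
The sketch below shows one generic way to realize a 2D rotary embedding, assuming each token carries a (chain index, residue index) pair and that half of the rotary dimensions are rotated by each index. This is a simplified illustration, not the exact scheme used by the model.

```python
# Generic 2D RoPE sketch: rotate half of the feature dimensions by the chain
# index and the other half by the residue index (assumed layout).
import torch

def rope_rotate(x, positions, base=10000.0):
    """Standard RoPE rotation of the last dim of x using integer positions."""
    d = x.shape[-1]
    half = d // 2
    freqs = base ** (-torch.arange(half, dtype=torch.float32) / half)
    angles = positions[..., None].float() * freqs          # (..., half)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

def rope_2d(x, chain_idx, residue_idx):
    """Apply RoPE with chain indices on one half of the features and residue indices on the other."""
    d = x.shape[-1]
    x_chain, x_res = x[..., : d // 2], x[..., d // 2 :]
    return torch.cat([rope_rotate(x_chain, chain_idx), rope_rotate(x_res, residue_idx)], dim=-1)

# Toy usage: 5 tokens with a feature dimension of 8
q = torch.randn(5, 8)
chain_idx = torch.tensor([0, 0, 0, 1, 1])     # which MSA sequence the token belongs to
residue_idx = torch.tensor([0, 1, 2, 0, 1])   # residue position within that sequence
q_rot = rope_2d(q, chain_idx, residue_idx)
```
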
+ **Loss**: The loss function consists of a sequence loss and a structure loss (with weights 1.0 and 0.01, respectively). The sequence loss is the cross-entropy of recovering the masked sequence tokens, and the structure loss is the cross-entropy of predicting the masked structure tokens.
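
A minimal sketch of the weighted objective is shown below; the shapes, vocabulary sizes, and use of `-100` to mark unmasked positions are assumptions for illustration.

```python
# Sketch: cross-entropy on masked sequence tokens plus 0.01 * cross-entropy on
# masked structure tokens. Labels use -100 at unmasked positions so that
# F.cross_entropy ignores them.
import torch
import torch.nn.functional as F

def pretraining_loss(seq_logits, seq_labels, struct_logits, struct_labels,
                     seq_weight=1.0, struct_weight=0.01):
    seq_loss = F.cross_entropy(seq_logits.transpose(1, 2), seq_labels, ignore_index=-100)
    struct_loss = F.cross_entropy(struct_logits.transpose(1, 2), struct_labels, ignore_index=-100)
    return seq_weight * seq_loss + struct_weight * struct_loss

# Toy shapes: batch of 2, length 16, 44 sequence tokens, 512 structure tokens (assumed)
B, L, V_seq, V_struct = 2, 16, 44, 512
loss = pretraining_loss(
    torch.randn(B, L, V_seq), torch.randint(0, V_seq, (B, L)),
    torch.randn(B, L, V_struct), torch.randint(0, V_struct, (B, L)),
)
```
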
| Hyper-params | (1) 1D -> 2D finetuning | (2) UniRef50/Uniclust30 MSA finetuning | (3) AFDB MSA & Structure tokens finetuning |
| --------------------------- | :---------------------: | :------------------------------------: | :----------------------------------------: |
+ | Initialized parameters | AIDO.Protein-16B | Stage (1) | Stage (2) |
| Data | ColabFoldDB, UniRef | HHblits_MSA, Retriever_MSA | AFDB MSA & Structure tokens |
| Global Batch Size | 512 | 256 | 256 |
| Sequence length | 2048 | 12800 | 12800 |
| Per Device Micro Batch Size | 1 | 1 | 1 |
| Precision | Mixed FP32-FP16 | Mixed FP32-FP16 | Mixed FP32-FP16 |
+ | LR | [5e-6, 5e-5] | [1e-6, 1e-5] | 1e-5 |
+ | Num Tokens | 10 billion | 100 billion | 80 billion |
 
### Tokenization

We encode protein sequences at single-amino-acid resolution with a 44-token vocabulary, where 24 tokens represent amino acid types and 20 are special tokens. Sequences are also suffixed with a `[SEP]` token as a hook for downstream tasks.
+ ## Results
+
+ ### Supervised downstream tasks
+
+ <center><img src="supervised_tasks.png" alt="supervised_tasks" style="width:90%; height:auto;" /></center>
+
+ ### Supervised DMS fitness score prediction of 25 samples
+
+ <center><img src="supervised_dms.png" alt="supervised_dms" style="width:90%; height:auto;" /></center>
## How to Use

### Build any downstream models from this backbone with ModelGenerator
 
from modelgenerator.tasks import Embed
model = Embed.from_config({"model.backbone": "aido_ragprotein_16b"}).eval()
model.backbone.max_length = 12800
+ data = torch.load("examples.pt", 'cpu')[0]
transformed_batch = model.transform(data)
with torch.no_grad():
    embedding = model(transformed_batch)
 
from modelgenerator.tasks import SequenceClassification
model = SequenceClassification.from_config({"model.backbone": "aido_ragprotein_16b", "model.n_classes": 2}).eval()
model.backbone.max_length = 12800
+ data = torch.load("examples.pt", 'cpu')[0]
transformed_batch = model.transform(data)
with torch.no_grad():
    logits = model(transformed_batch)
 
from modelgenerator.tasks import TokenClassification
model = TokenClassification.from_config({"model.backbone": "aido_ragprotein_16b", "model.n_classes": 3}).eval()
model.backbone.max_length = 12800
+ data = torch.load("examples.pt", 'cpu')[0]
transformed_batch = model.transform(data)
with torch.no_grad():
    logits = model(transformed_batch)
 
from modelgenerator.tasks import SequenceRegression
model = SequenceRegression.from_config({"model.backbone": "aido_ragprotein_16b"}).eval()
model.backbone.max_length = 12800
+ data = torch.load("examples.pt", 'cpu')[0]
transformed_batch = model.transform(data)
with torch.no_grad():
    logits = model(transformed_batch)
examples.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e60409786b001317d5150ab440521fa1f9a6a90ca18f4666e27740b4e6a75aa5
+ size 5485326