oucgc1996 committed (verified)
Commit 6e03af6 · 1 Parent(s): 5104b9a

Upload vocab.py

Files changed (1):
  vocab.py +193 -0
vocab.py ADDED
import re
import pandas as pd


class PepVocab:
    def __init__(self):
        self.token_to_idx = {
            '<MASK>': -1, '<PAD>': 0, 'A': 1, 'C': 2, 'E': 3, 'D': 4, 'F': 5, 'I': 6, 'H': 7,
            'K': 8, 'M': 9, 'L': 10, 'N': 11, 'Q': 12, 'P': 13, 'S': 14,
            'R': 15, 'T': 16, 'W': 17, 'V': 18, 'Y': 19, 'G': 20, 'O': 21, 'U': 22, 'Z': 23, 'X': 24}
        self.idx_to_token = {
            -1: '<MASK>', 0: '<PAD>', 1: 'A', 2: 'C', 3: 'E', 4: 'D', 5: 'F', 6: 'I', 7: 'H',
            8: 'K', 9: 'M', 10: 'L', 11: 'N', 12: 'Q', 13: 'P', 14: 'S',
            15: 'R', 16: 'T', 17: 'W', 18: 'V', 19: 'Y', 20: 'G', 21: 'O', 22: 'U', 23: 'Z', 24: 'X'}

        self.get_attention_mask = False
        self.attention_mask = []

    def set_get_attn(self, is_get: bool):
        self.get_attention_mask = is_get

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        '''
        note: the input should be an already split sequence

        Args:
            tokens: a single token, or a (possibly nested) list of split tokens
        '''
        if not isinstance(tokens, (list, tuple)):
            # return self.token_to_idx.get(tokens)
            return self.token_to_idx[tokens]
        return [self.__getitem__(token) for token in tokens]

    def vocab_from_txt(self, path):
        '''
        note: this function builds the vocab mapping from a text file,
        but it only supports one specific format: a single-column,
        tab-separated file, whose column is read as column 0
        '''
        token_to_idx = {}
        idx_to_token = {}
        chr_idx = pd.read_csv(path, header=None, sep='\t')
        if chr_idx.shape[1] == 1:
            for idx, token in enumerate(chr_idx[0]):
                token_to_idx[token] = idx
                idx_to_token[idx] = token
            self.token_to_idx = token_to_idx
            self.idx_to_token = idx_to_token

    def to_tokens(self, indices):
        '''
        note: the input should be an integer or a list of integers
        '''
        # branch on whether a sequence of indices was passed, so that
        # single-element lists are handled as well
        if hasattr(indices, '__len__'):
            return [self.idx_to_token[int(index)] for index in indices]
        return self.idx_to_token[indices]

    def add_special_token(self, token: str | list | tuple) -> None:
        if not isinstance(token, (list, tuple)):
            if token in self.token_to_idx:
                raise ValueError(f"token {token} already in the vocab")
            self.idx_to_token[len(self.idx_to_token)] = token
            self.token_to_idx[token] = len(self.token_to_idx)
        else:
            [self.add_special_token(t) for t in token]

    def split_seq(self, seq: str | list | tuple) -> list:
        if not isinstance(seq, (list, tuple)):
            return re.findall(r"<[a-zA-Z0-9]+>|[a-zA-Z-]", seq)
        return [self.split_seq(s) for s in seq]  # a list of lists

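    # For example (illustrative input), split_seq('<MASK>ACD') returns
    # ['<MASK>', 'A', 'C', 'D']: bracketed special tokens stay whole while
    # plain residues are split one by one.
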
    def truncate_pad(self, line, num_steps, padding_token='<PAD>') -> list:
        if not isinstance(line[0], list):
            if len(line) > num_steps:
                if self.get_attention_mask:
                    self.attention_mask.append([1] * num_steps)
                return line[:num_steps]
            if self.get_attention_mask:
                self.attention_mask.append([1] * len(line) + [0] * (num_steps - len(line)))
            return line + [padding_token] * (num_steps - len(line))
        else:
            return [self.truncate_pad(l, num_steps, padding_token) for l in line]  # a list of lists

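    # Illustrative behaviour of truncate_pad (example values): truncate_pad(['A', 'C', 'D'], 5)
    # returns ['A', 'C', 'D', '<PAD>', '<PAD>'] and, when get_attention_mask is True,
    # records the mask [1, 1, 1, 0, 0].
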
    def get_attention_mask_mat(self):
        attention_mask = self.attention_mask
        self.attention_mask = []
        return attention_mask

    def seq_to_idx(self, seq: str | list | tuple, num_steps: int, padding_token='<PAD>') -> list:
        '''
        note: make sure to call this function after add_special_token
        '''

        splited_seq = self.split_seq(seq)
        # **********************
        # after splitting, the sequence still needs to be masked
        # note:
        # 1. mask tokens by probability
        # 2. return a list or a list of lists
        # **********************
        padded_seq = self.truncate_pad(splited_seq, num_steps, padding_token)

        return self.__getitem__(padded_seq)

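
# Usage sketch for PepVocab (illustrative only; the sequences and num_steps are made up):
#
#   vocab = PepVocab()
#   vocab.set_get_attn(True)
#   ids = vocab.seq_to_idx(['ACDE', 'KLMNPQR'], num_steps=6)  # nested list of token indices
#   mask = vocab.get_attention_mask_mat()                     # matching attention masks
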

class MutilVocab:
    def __init__(self, data, AA_tok_len=2):
        """
        Args:
            data: an iterable of AA sequences
            AA_tok_len (int, optional): length of each token in residues. Defaults to 2.
        """
        ## Load train dataset
        self.x_data = data
        self.tok_AA_len = AA_tok_len
        self.default_AA = list("RHKDESTNQCGPAVILMFYW")
        # all tokens of length tok_AA_len built from default_AA
        self.tokens = self._token_gen(self.tok_AA_len)

        self.token_to_idx = {k: i + 4 for i, k in enumerate(self.tokens)}
        self.token_to_idx["[PAD]"] = 0   ## idx 0 is PAD
        self.token_to_idx["[CLS]"] = 1   ## idx 1 is CLS
        self.token_to_idx["[SEP]"] = 2   ## idx 2 is SEP
        self.token_to_idx["[MASK]"] = 3  ## idx 3 is MASK

    def split_seq(self):
        self.X = [self._seq_to_tok(seq) for seq in self.x_data]
        return self.X

    def tok_idx(self, seqs):
        '''
        note: make sure to call this function before any truncation/padding
        '''

        seqs_idx = []
        for seq in seqs:
            seq_idx = []
            for s in seq:
                seq_idx.append(self.token_to_idx[s])
            seqs_idx.append(seq_idx)

        return seqs_idx

    def _token_gen(self, tok_AA_len: int, st: str = "", curr_depth: int = 0):
        """Generate all tokens built from the default amino acid residues.
        The length of AAs in each token is given by "tok_AA_len".

        Args:
            tok_AA_len (int): Length of each token
            st (str, optional): Defaults to ''.
            curr_depth (int, optional): Defaults to 0.

        Returns:
            List: List of tokens
        """
        curr_depth += 1
        if curr_depth <= tok_AA_len:
            l = [
                st + t
                for s in self.default_AA
                for t in self._token_gen(tok_AA_len, s, curr_depth)
            ]
            return l
        else:
            return [st]

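    # For example, self._token_gen(2) enumerates every ordered pair of the 20
    # default residues, i.e. 20 * 20 = 400 two-residue tokens.
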
    def _seq_to_tok(self, seq: str):
        """Split an AA sequence into overlapping tokens of length tok_AA_len,
        wrapped with [CLS] and [SEP].

        Args:
            seq (str): AA sequence

        Returns:
            list: A list of tokens
        """

        seq_idx = []

        seq_idx += ["[CLS]"]

        for i in range(len(seq) - self.tok_AA_len + 1):
            curr_token = seq[i : i + self.tok_AA_len]
            seq_idx.append(curr_token)
        seq_idx += ['[SEP]']
        return seq_idx
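
# Usage sketch for MutilVocab (illustrative only; the sequences are made up):
#
#   mv = MutilVocab(["ACDEF", "RHKDE"], AA_tok_len=2)
#   toks = mv.split_seq()   # e.g. ['[CLS]', 'AC', 'CD', 'DE', 'EF', '[SEP]'] for the first sequence
#   ids = mv.tok_idx(toks)  # per-sequence lists of token indices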