import os
import logging

import sentencepiece as spm
from transformers.tokenization_utils import PreTrainedTokenizer

logger = logging.getLogger(__name__)
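

# Standalone XLM-R tokenizer built from the raw fairseq artifacts: a
# sentencepiece BPE model for segmentation, plus a fairseq dict.txt that
# fixes the token-to-id mapping so ids match the released checkpoint
# rather than sentencepiece's internal vocabulary order.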
class XLMRTokenizer(PreTrainedTokenizer):
    def __init__(self, bpe_file, dict_file, **kwargs):
        super(XLMRTokenizer, self).__init__(
            bos_token="<s>",
            eos_token="</s>",
            unk_token="<unk>",
            pad_token="<pad>",
            mask_token="<mask>",
            sep_token="</s>",
            cls_token="<s>",
            **kwargs)
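        # Reserve room for the special tokens that build_inputs_with_special_tokens
        # adds: 2 for a single sentence (<s> ... </s>), 4 for a pair
        # (<s> ... </s></s> ... </s>).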
        self.max_len_single_sentence = self.max_len - 2
        self.max_len_sentences_pair = self.max_len - 4
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(bpe_file)
        self.encoder = {}
        self.decoder = []
        # The first four ids follow the fairseq dictionary convention:
        # 0 = <s>, 1 = <pad>, 2 = </s>, 3 = <unk>.
        for token in [self.bos_token, self.pad_token, self.eos_token, self.unk_token]:
            self._add_token(token)
        with open(dict_file, encoding="utf-8") as fp:
            for line in fp:
                # NOTE: split on a single ASCII space only. A bare .split()
                # would also break on other whitespace characters that can
                # legitimately occur inside sentencepiece tokens.
                tokens_cnt = line.rstrip().split(" ")
                if len(tokens_cnt) < 2:
                    logger.error(
                        "malformed dict line %r (expected '<token> <count>') at index %d",
                        line, len(self.decoder))
                    raise ValueError("malformed dict file line: %r" % line)
                # Everything before the trailing count column is the token.
                self._add_token(" ".join(tokens_cnt[:-1]))

    def _add_token(self, token):
        idx = len(self.encoder)
        self.encoder[token] = idx
        self.decoder.append(token)

    def _tokenize(self, text):
        return self.sp.EncodeAsPieces(text)

    def _convert_id_to_token(self, index):
        return self.decoder[index]

    def _convert_token_to_id(self, token):
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def convert_tokens_to_string(self, tokens):
        # U+2581 is the sentencepiece word-boundary marker.
        return "".join(tokens).replace("\u2581", " ").strip()
    @classmethod
    def from_pretrained(cls, model_path, **kwargs):
        bpe_file = os.path.join(model_path, "sentencepiece.bpe.model")
        dict_file = os.path.join(model_path, "dict.txt")
        tokenizer = cls(bpe_file, dict_file, **kwargs)
        return tokenizer

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        # Single sequence: <s> A </s>
        # Pair of sequences: <s> A </s></s> B </s>  (RoBERTa-style)
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + sep + token_ids_1 + sep

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        # Returns 1 for positions holding special tokens, 0 for sequence tokens.
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided "
                    "sequence of ids is already formatted with special tokens "
                    "for the model.")
            return [1 if x in (self.sep_token_id, self.cls_token_id) else 0 for x in token_ids_0]
        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(sep + token_ids_1 + sep) * [1]
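

# Quick smoke test across scripts (Latin, CJK, Devanagari, Arabic). The
# checkpoint path below is specific to the original author's environment.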
if __name__ == "__main__":
    tokenizer = XLMRTokenizer.from_pretrained("/home/v-zechi/data/unilm/zechi/exp/bert_data/xlmr-large")
    for text in ["Hello world!", "你好,世界", "नमस्ते दुनिया", "مرحبا بالعالم", "Bonjour le monde"]:
        print(tokenizer.tokenize(text))
        print(tokenizer.encode_plus(text, text, add_special_tokens=True))