|
import logging |
|
import os |
|
import json |
|
from typing import Optional, Dict, List, Set, Tuple, Union, Literal, Type |
|
from pydantic.dataclasses import dataclass |
|
|
|
import numpy as np |
|
from numpy.typing import NDArray |
|
|
|
from transformers import PreTrainedTokenizerFast |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
# Extra vocabulary file this tokenizer ships alongside the standard
# tokenizer files (consumed by `DartTokenizer.vocab_files_names`).
VOCAB_FILES_NAMES = {
    "tag_category": "tag_category.json",
}

# Hub download locations of the extra vocabulary file for known
# pretrained checkpoints.
PRETRAINED_VOCAB_FILES_MAP = {
    "tag_category": {
        "p1atdev/tokenizer_test_1": "https://huggingface.co/p1atdev/tokenizer_test_1/resolve/main/tag_category.json"
    }
}
|
|
|
|
|
@dataclass
class Category:
    """Configuration of one tag category in the generation state machine."""

    # Human-readable category name.
    name: str
    # NOTE(review): not used anywhere in this file — presumably a cap on how
    # many tags of this category may be generated; confirm against callers.
    max_count: Optional[int]
    # Category ids whose BOS may follow once this category closes.
    next_category: List[int]
    # Whether the whole sequence may emit EOS right after this category closes.
    can_end: bool
    # Token id that opens this category.
    bos_token_id: int
    # Token id that closes this category.
    eos_token_id: int
    # Fill value (expected 0 or 1) for this category's default vocab mask.
    default_mask: int
|
|
|
|
|
@dataclass
class SpecialMapping:
    """Per-token mask overrides applied right after a trigger token appears."""

    # Token ids whose mask entries are forced to 1 (enabled).
    allow: List[int]
    # Token ids whose mask entries are forced to 0 (banned).
    disallow: List[int]
|
|
|
|
|
@dataclass
class TagCategoryConfig:
    """Parsed contents of the `tag_category.json` sidecar file."""

    # Id of the category that must open immediately after BOS.
    start_category: int
    # Category id (as a string key) -> category definition.
    categories: Dict[str, Category]
    # Trigger token id (as string) -> {category id (as string) -> overrides}.
    special_mapping: Dict[str, Dict[str, SpecialMapping]]
    # Category id (as string) -> token ids belonging to that category.
    category_tags_pairs: Dict[str, List[int]]
|
|
|
|
|
class OverrideMask:
    """Pair of vocab-mask arrays: tokens to force-allow and to force-ban.

    The arrays are stored by reference (not copied).
    """

    allow: np.ndarray
    disallow: np.ndarray

    def __init__(self, allow: np.ndarray, disallow: np.ndarray) -> None:
        self.allow, self.disallow = allow, disallow
|
|
|
|
|
def load_tag_category(config_json: str):
    """Read a tag-category JSON file and parse it into a ``TagCategoryConfig``.

    Nested dicts are coerced into ``Category``/``SpecialMapping`` instances
    by the pydantic dataclass constructor.
    """
    with open(config_json, "rb") as file:
        raw = json.load(file)
    config: TagCategoryConfig = TagCategoryConfig(**raw)
    return config
|
|
|
|
|
class DartTokenizer(PreTrainedTokenizerFast):
    """Dart tokenizer.

    A fast tokenizer augmented with tag-category transition rules
    (loaded from ``tag_category.json``) used to build per-step vocab
    masks that constrain generation to valid tag sequences.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP

    def __init__(self, tag_category, **kwargs):
        """
        Args:
            tag_category: Path to the tag-category JSON config file
                (see ``VOCAB_FILES_NAMES``).
            **kwargs: Forwarded to ``PreTrainedTokenizerFast``.
        """
        super().__init__(**kwargs)

        self.tag_category_config = load_tag_category(tag_category)

        # Reverse lookups: category BOS/EOS token id -> category id.
        self.category_bos_map = {
            category.bos_token_id: category_id
            for category_id, category in self.tag_category_config.categories.items()
        }
        self.category_eos_map = {
            category.eos_token_id: category_id
            for category_id, category in self.tag_category_config.categories.items()
        }

        # Map every vocab token id to its category id (0 for unassigned ids).
        self._id_to_category_map = np.zeros(self.vocab_size, dtype="uint8")
        for category_id, tokens in self.tag_category_config.category_tags_pairs.items():
            self._id_to_category_map[tokens] = int(category_id)

        # Template masks. Treated as read-only: get_next_tokens_mask copies
        # them before applying per-sequence overrides.
        self.category_mask = self.create_category_vocab_mask()

    def create_vocab_mask(self, value: int = 1):
        """Create an array of vocab size filled with the specified value."""
        # dtype passed directly instead of .astype() to skip the
        # intermediate int64 allocation.
        return np.full(self.vocab_size, value, dtype="uint8")

    def create_category_vocab_mask(self):
        """Create vocab masks for each category, filled with its default_mask."""
        return {
            category_id: self.create_vocab_mask(value=category.default_mask)
            for category_id, category in self.tag_category_config.categories.items()
        }

    def get_token_ids_in_category(self, category_id: Union[int, str]):
        """Get token ids in the specified category."""
        return self.tag_category_config.category_tags_pairs[str(category_id)]

    def get_category(self, category_id: Union[int, str]):
        """Get the specified category config."""
        return self.tag_category_config.categories[str(category_id)]

    def get_special_mapping(self, token_id: Union[int, str]):
        """Get the special mapping of the specified token id."""
        return self.tag_category_config.special_mapping[str(token_id)]

    def get_banned_tokens_mask(self, tokens: Union[str, List[str], int, List[int]]):
        """Build a vocab mask with the given tokens set to 0 and all others 1.

        Accepts a single token or a list; string tokens are converted to ids.
        """
        if isinstance(tokens, (str, int)):
            tokens = [tokens]

        # Convert string tokens to ids. Conversion now also covers the
        # single-string input path, which previously skipped it and
        # tripped the int check below.
        tokens = [
            self.convert_tokens_to_ids(token) if isinstance(token, str) else token
            for token in tokens
        ]

        assert isinstance(tokens, list) and all(
            isinstance(token, int) for token in tokens
        )

        mask = self.create_vocab_mask(value=1)
        mask[tokens] = 0

        return mask

    def convert_ids_to_category_ids(self, token_ids: Union[int, List[int]]):
        """Map token id(s) to their category id(s) via the precomputed table."""
        return self._id_to_category_map[token_ids]

    def get_next_tokens_mask(
        self,
        input_ids: List[int],
        category_mask: Optional[Dict[str, np.ndarray]] = None,
    ) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
        """Get the next token's vocab mask and the (possibly updated) category mask.

        Args:
            input_ids: Token ids generated so far.
            category_mask: Per-category masks carried across steps. Pass the
                mask returned by the previous call; when None, a fresh copy of
                the default masks is used.

        Returns:
            A (vocab_mask, category_mask) tuple; vocab_mask has 1 for
            allowed next tokens, 0 for banned ones.
        """
        if category_mask is None:
            # Copy the templates so the in-place special-mapping updates
            # below never leak into self.category_mask across calls.
            category_mask = {
                category_id: mask.copy()
                for category_id, mask in self.category_mask.items()
            }

        vocab_mask = self.create_vocab_mask(value=0)

        if len(input_ids) == 0:
            # Nothing generated yet: only BOS may start the sequence.
            vocab_mask[self.bos_token_id] = 1
            return vocab_mask, category_mask

        last_token_id = input_ids[-1]

        if last_token_id == self.unk_token_id:
            # Cannot reason about an unknown token; allow everything.
            logger.warning(
                "The unk_token was provided! The vocab mask could not be created properly."
            )
            return self.create_vocab_mask(value=1), category_mask

        # Apply per-token overrides to the carried category masks.
        if str(last_token_id) in self.tag_category_config.special_mapping:
            for category_id, mapping in self.get_special_mapping(last_token_id).items():
                category_mask[category_id][mapping.allow] = 1
                category_mask[category_id][mapping.disallow] = 0

        if last_token_id == self.bos_token_id:
            # After BOS only the start category's opening token is valid.
            start_category_id = self.tag_category_config.start_category
            start_category = self.get_category(start_category_id)

            vocab_mask[start_category.bos_token_id] = 1
            return vocab_mask, category_mask

        elif last_token_id == self.eos_token_id:
            # After EOS only padding is allowed.
            vocab_mask[self.pad_token_id] = 1
            return vocab_mask, category_mask

        elif last_token_id in self.category_bos_map:
            # A category just opened: allow its tokens (filtered by the
            # category mask) or an immediate category close.
            current_category_id = self.category_bos_map[last_token_id]
            category = self.get_category(current_category_id)

            tokens_in_category = self.get_token_ids_in_category(current_category_id)
            vocab_mask[tokens_in_category] = 1

            vocab_mask *= category_mask[str(current_category_id)]
            # Category EOS set after the multiply: always allowed here.
            vocab_mask[category.eos_token_id] = 1

            return vocab_mask, category_mask

        elif last_token_id in self.category_eos_map:
            # A category just closed: allow sequence EOS (if permitted)
            # and/or opening any successor category.
            current_category_id = self.category_eos_map[last_token_id]
            category = self.get_category(current_category_id)

            if category.can_end:
                vocab_mask[self.eos_token_id] = 1

            for next_category_id in category.next_category:
                vocab_mask[self.get_category(next_category_id).bos_token_id] = 1

            return vocab_mask, category_mask

        else:
            # Mid-category tag: allow remaining tokens of the same category
            # or the category EOS, minus anything already generated.
            current_category_id = self.convert_ids_to_category_ids(last_token_id).item()
            tokens_in_category = self.get_token_ids_in_category(current_category_id)

            vocab_mask[tokens_in_category] = 1
            # Category EOS set before the multiply: it stays subject to the
            # category mask in this branch (unlike the BOS branch above).
            vocab_mask[self.get_category(current_category_id).eos_token_id] = 1
            vocab_mask *= category_mask[str(current_category_id)]
            # No-repeat constraint: ban every token already generated.
            vocab_mask[input_ids] = 0

            return vocab_mask, category_mask
|
|