import os
import json
import regex as re
from typing import Dict, List, Optional, Tuple, Union
from transformers import PreTrainedTokenizer
from transformers.utils import logging
logger = logging.get_logger(__name__)


class AksharaTokenizer(PreTrainedTokenizer):
    """
    Akshara tokenizer for processing Indic language text.

    This tokenizer handles characters at the akshara (syllable) level.
    """

    vocab_files_names = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        merges_file=None,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token="<pad>",
        mask_token="<mask>",
        add_prefix_space=False,
        **kwargs
    ):
        # Load the vocabulary before calling the parent constructor: recent
        # transformers releases resolve special tokens against get_vocab()
        # inside PreTrainedTokenizer.__init__, so self.encoder must exist first.
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}

        # Load merges if available (one space-separated pair per line, in rank order)
        self.merges = {}
        if merges_file is not None and os.path.isfile(merges_file):
            with open(merges_file, encoding="utf-8") as merges_handle:
                merges = merges_handle.read().split("\n")
            self.merges = {tuple(merge.split()): i for i, merge in enumerate(merges) if merge}

        # Whether a leading space is prepended before tokenization
        self.add_prefix_space = add_prefix_space

        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            mask_token=mask_token,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )

        # Pre-compile regex patterns for tokenization
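        # The pattern below is the GPT-2-style pre-tokenizer: it splits out common
        # English contractions, then runs of letters, digits, or other symbols
        # (each optionally preceded by a single space), and finally whitespace.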
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    @property
    def vocab_size(self):
        return len(self.encoder)

    def get_vocab(self):
        return dict(self.encoder, **self.added_tokens_encoder)

    def _tokenize(self, text):
        """Tokenize text into akshara units."""
        if self.add_prefix_space and not text.startswith(" "):
            text = " " + text
        tokens = re.findall(self.pat, text)
        return tokens

    def _convert_token_to_id(self, token):
        """Convert a token to its ID in the vocabulary."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Convert an ID to its token in the vocabulary."""
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Convert a sequence of tokens to a single string."""
text = "".join(tokens)
text = text.replace(" ", "").replace("▁", " ").strip()
return text

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Build model inputs from a sequence or a pair of sequences by adding bos and eos tokens."""
        if token_ids_1 is None:
            return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """Return a list with 1 for special tokens and 0 for sequence tokens."""
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )
        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """Create a mask from the two sequences for sequence classification tasks."""
        eos = [self.eos_token_id]
        bos = [self.bos_token_id]
        if token_ids_1 is None:
            return len(bos + token_ids_0 + eos) * [0]
        return len(bos + token_ids_0 + eos) * [0] + len(token_ids_1 + eos) * [1]

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Save the vocabulary (and the merges, when present) to a directory."""
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"]
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, ensure_ascii=False))
        if not self.merges:
            return (vocab_file,)
        merges_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["merges_file"]
        )
        with open(merges_file, "w", encoding="utf-8") as f:
            # Write merges in rank order, one space-separated pair per line
            for pair, _ in sorted(self.merges.items(), key=lambda kv: kv[1]):
                f.write(" ".join(pair) + "\n")
        return vocab_file, merges_file


# Register the tokenizer with the Auto classes so it can be resolved by model type.
from transformers import AutoConfig, AutoTokenizer, PretrainedConfig
from transformers.models.auto.configuration_auto import CONFIG_MAPPING


class AksharaConfig(PretrainedConfig):
    model_type = "akshara"


# Register the model configuration if needed
if "akshara" not in CONFIG_MAPPING:
    AutoConfig.register("akshara", AksharaConfig)
# AutoTokenizer.register expects the config class, not the model-type string.
AutoTokenizer.register(AksharaConfig, slow_tokenizer_class=AksharaTokenizer)
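

if __name__ == "__main__":
    # Minimal usage sketch, not part of the tokenizer itself. It assumes a
    # vocab.json file (mapping akshara strings to integer ids) exists in the
    # working directory; the file name and sample text are illustrative only.
    tokenizer = AksharaTokenizer(vocab_file="vocab.json")
    sample = "नमस्ते दुनिया"
    print(tokenizer.tokenize(sample))   # pieces produced by the regex pre-tokenizer
    print(tokenizer.encode(sample))     # ids with <s> ... </s> added
    print(tokenizer.decode(tokenizer.encode(sample)))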