import torch
from collections import deque
from jamotools import split_syllables, join_jamos
from transformers import PretrainedConfig, PreTrainedModel, AutoTokenizer


class HangulTokenizerConfig(PretrainedConfig):
    """Configuration holding the name of the base tokenizer to wrap."""

    model_type = "hangul_tokenizer"

    def __init__(
        self,
        base_tokenizer_name='unsloth/gemma-2-2b',
        **kwargs
    ):
        super().__init__(**kwargs)
        self.base_tokenizer_name = base_tokenizer_name


class HangulTokenizer(PreTrainedModel):
    """Tokenizer wrapper that encodes every precomposed Hangul syllable as a
    fixed group of three token ids.

    Two encodings are offered:

    * "jamo": a syllable becomes (initial, medial, final) jamo token ids with
      token types ``1, 2, 3``; a missing final is the pad token.
    * "char": a syllable becomes the base tokenizer's ids for that character,
      right-padded with ``pad_token_id`` to length three, with token types
      ``4, 4, 4``.

    Every other run of characters is encoded by the base tokenizer with token
    type ``0``.
    """

    config_class = HangulTokenizerConfig

    # Unicode range of precomposed Hangul syllables: U+AC00 .. U+D7A3.
    _HANGUL_FIRST = 0xAC00
    _HANGUL_LAST = 0xD7A3

    def __init__(self, config):
        super().__init__(config)
        # Single dummy parameter.
        # NOTE(review): presumably present so the PreTrainedModel machinery
        # has at least one weight to save/load -- confirm.
        self.temp_module = torch.nn.Parameter(torch.ones(1))
        self.base_tokenizer = AutoTokenizer.from_pretrained(config.base_tokenizer_name)
        # NOTE(review): id 128 is hard-coded as the padding slot; verify it is
        # appropriate for the configured base tokenizer.
        self.base_tokenizer.pad_token_id = 128
        self.base_tokenizer.pad_token = self.base_tokenizer.decode([self.base_tokenizer.pad_token_id])
        self.space_token_id = self.base_tokenizer.encode(' ', add_special_tokens=False)[-1]

        # Built in code-point order so the tables below are identical on every
        # run (a set-derived order depends on the string hash seed).
        self.kor_chars = [
            chr(code) for code in range(self._HANGUL_FIRST, self._HANGUL_LAST + 1)
        ]
        self.char_3ids = []
        self.char_1ids = []
        for kor_char in self.kor_chars:
            ids = self.base_tokenizer.encode(kor_char, add_special_tokens=False)
            if len(ids) == 3:
                self.char_3ids.append(ids)
            else:
                self.char_1ids.append(self._pad_to_triple(ids))

        self.chos = ['ㄱ', 'ㄲ', 'ㄴ', 'ㄷ', 'ㄸ', 'ㄹ', 'ㅁ', 'ㅂ', 'ㅃ', 'ㅅ', 'ㅆ', 'ㅇ', 'ㅈ', 'ㅉ', 'ㅊ', 'ㅋ', 'ㅌ', 'ㅍ', 'ㅎ']
        self.joongs = ['ㅏ', 'ㅐ', 'ㅑ', 'ㅒ', 'ㅓ', 'ㅔ', 'ㅕ', 'ㅖ', 'ㅗ', 'ㅘ', 'ㅙ', 'ㅚ', 'ㅛ', 'ㅜ', 'ㅝ', 'ㅞ', 'ㅟ', 'ㅠ', 'ㅡ', 'ㅢ', 'ㅣ']
        # Index 0 is the pad token: it stands for "no final consonant".
        self.jongs = [self.base_tokenizer.pad_token, 'ㄱ', 'ㄲ', 'ㄳ', 'ㄴ', 'ㄵ', 'ㄶ', 'ㄷ', 'ㄹ', 'ㄺ', 'ㄻ', 'ㄼ', 'ㄽ', 'ㄾ', 'ㄿ', 'ㅀ', 'ㅁ', 'ㅂ', 'ㅄ', 'ㅅ', 'ㅆ', 'ㅇ', 'ㅈ', 'ㅊ', 'ㅋ', 'ㅌ', 'ㅍ', 'ㅎ']

        # Unique jamo in first-seen order (deterministic, unlike a set union).
        jamos = list(dict.fromkeys(self.chos + self.joongs + self.jongs))
        jamo_ids = self.base_tokenizer(jamos, add_special_tokens=False)['input_ids']
        # Only the last id of each encoding is kept as the jamo's token id.
        self.jamo_to_id = {jamo: jamo_id[-1] for jamo, jamo_id in zip(jamos, jamo_ids)}
        self.cho_ids = [self.jamo_to_id[cho] for cho in self.chos]
        self.joong_ids = [self.jamo_to_id[joong] for joong in self.joongs]
        self.jong_ids = [self.jamo_to_id[jong] for jong in self.jongs]
        self.id_to_jamo = {jamo_id: jamo for jamo, jamo_id in self.jamo_to_id.items()}

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @classmethod
    def _is_hangul_syllable(cls, char):
        """Return True if *char* is a precomposed Hangul syllable (U+AC00..U+D7A3)."""
        return cls._HANGUL_FIRST <= ord(char) <= cls._HANGUL_LAST

    @staticmethod
    def _to_list(sequence):
        """Return *sequence* as a plain list (tensors/arrays via ``tolist()``)."""
        if hasattr(sequence, 'tolist'):
            return sequence.tolist()
        return list(sequence)

    @staticmethod
    def _take_triple(ids, types, pos):
        """Return ``ids[pos:pos + 3]``; raise IndexError if the group is truncated."""
        end = pos + 3
        if end > len(ids) or end > len(types):
            raise IndexError(
                f"incomplete 3-token syllable group starting at position {pos}"
            )
        return ids[pos:end]

    def _pad_to_triple(self, ids):
        """Right-pad *ids* with ``pad_token_id`` up to three entries."""
        return ids + (3 - len(ids)) * [self.base_tokenizer.pad_token_id]

    def _syllable_to_jamo_ids(self, char):
        """Return the three jamo token ids (initial, medial, final) of *char*."""
        jamos = list(split_syllables(char))[:3]
        # A syllable without a final consonant gets the pad token instead.
        jamos = jamos + (3 - len(jamos)) * [self.base_tokenizer.pad_token]
        cho, joong, jong = jamos
        return [self.jamo_to_id[cho], self.jamo_to_id[joong], self.jamo_to_id[jong]]

    def _syllable_to_char_ids(self, char):
        """Return the base-tokenizer ids of *char*, padded to three entries."""
        return self._pad_to_triple(
            self.base_tokenizer.encode(char, add_special_tokens=False)
        )

    def _encode_mixed(self, sentence, encode_syllable, syllable_types):
        """Encode *sentence*, routing Hangul syllables through *encode_syllable*.

        Runs of non-Hangul characters are buffered and encoded in one call to
        the base tokenizer (token type 0). Each Hangul syllable contributes the
        ids returned by *encode_syllable* with token types *syllable_types*.
        """
        encoded_ids = []
        token_type_ids = []
        plain_chars = []

        def flush_plain():
            if not plain_chars:
                return
            plain_ids = self.base_tokenizer.encode(
                ''.join(plain_chars), add_special_tokens=False
            )
            encoded_ids.extend(plain_ids)
            token_type_ids.extend([0] * len(plain_ids))
            plain_chars.clear()

        for char in sentence:
            if self._is_hangul_syllable(char):
                flush_plain()
                encoded_ids.extend(encode_syllable(char))
                token_type_ids.extend(syllable_types)
            else:
                plain_chars.append(char)
        flush_plain()
        return encoded_ids, token_type_ids

    def _decode_mixed(self, encoded_ids, token_type_ids, decode_triple):
        """Decode ids, turning every non-zero-typed 3-id group into one character.

        Consecutive type-0 ids are decoded together by the base tokenizer;
        every other position starts a group of three ids that is handed to
        *decode_triple*.
        """
        ids = self._to_list(encoded_ids)
        types = self._to_list(token_type_ids)
        pieces = []
        plain_ids = []
        pos = 0
        while pos < len(ids):
            if types[pos] == 0:
                plain_ids.append(ids[pos])
                pos += 1
                continue
            if plain_ids:
                pieces.append(self.base_tokenizer.decode(plain_ids))
                plain_ids = []
            pieces.append(decode_triple(self._take_triple(ids, types, pos)))
            pos += 3
        if plain_ids:
            pieces.append(self.base_tokenizer.decode(plain_ids))
        return ''.join(pieces)

    def _jamo_triple_to_char(self, triple):
        """Compose one syllable from three jamo token ids."""
        return join_jamos([self.id_to_jamo[jamo_id] for jamo_id in triple])[:1]

    def _char_triple_to_char(self, triple):
        """Decode one syllable from its (padded) base-tokenizer ids."""
        return self.base_tokenizer.decode(triple)[:1]

    def _pad_batch(self, rows_ids, rows_types):
        """Right-pad rows to a common length and build the attention mask.

        ``input_ids`` are padded with ``eos_token_id`` and token types with 0.
        The mask is derived from the true row lengths, so EOS tokens that are
        part of the content stay attended.
        """
        eos_id = self.base_tokenizer.eos_token_id
        max_length = max(map(len, rows_ids), default=0)
        input_ids = []
        attention_mask = []
        token_type_ids = []
        for ids, types in zip(rows_ids, rows_types):
            n_pad = max_length - len(ids)
            input_ids.append(ids + n_pad * [eos_id])
            attention_mask.append(len(ids) * [1] + n_pad * [0])
            token_type_ids.append(types + (max_length - len(types)) * [0])
        return (
            torch.LongTensor(input_ids),
            torch.LongTensor(attention_mask),
            torch.LongTensor(token_type_ids),
        )

    # ------------------------------------------------------------------
    # Single-sentence API
    # ------------------------------------------------------------------

    def encode_jamo(self, sentence):
        """Encode *sentence* with Hangul syllables split into three jamo ids.

        Returns:
            ``(encoded_ids, token_type_ids)``: two equally long lists. Token
            types are 0 for base-tokenizer tokens and 1, 2, 3 for the initial,
            medial and final jamo of a syllable.
        """
        return self._encode_mixed(sentence, self._syllable_to_jamo_ids, [1, 2, 3])

    def decode_jamo(self, encoded_ids, token_type_ids):
        """Decode the output of :meth:`encode_jamo` back into a string.

        Raises:
            IndexError: if a syllable group is cut short.
        """
        return self._decode_mixed(encoded_ids, token_type_ids, self._jamo_triple_to_char)

    def encode_char(self, sentence):
        """Encode *sentence* with each Hangul syllable occupying three ids.

        Returns:
            ``(encoded_ids, token_type_ids)``: two equally long lists. Token
            types are 0 for base-tokenizer tokens and 4 for each of the three
            ids of a syllable.
        """
        return self._encode_mixed(sentence, self._syllable_to_char_ids, [4, 4, 4])

    def decode_char(self, encoded_ids, token_type_ids):
        """Decode the output of :meth:`encode_char` back into a string.

        Raises:
            IndexError: if a syllable group is cut short.
        """
        return self._decode_mixed(encoded_ids, token_type_ids, self._char_triple_to_char)

    def encode_jamo_from_char_encoded(self, encoded_ids, token_type_ids):
        """Convert a "char" encoding into the equivalent "jamo" encoding.

        Type-0 tokens are copied unchanged; every other position starts a
        3-id group that is decoded to its syllable and re-encoded as three
        jamo ids with token types 1, 2, 3. The sequence length is preserved.

        Raises:
            IndexError: if a syllable group is cut short or decodes to an
                empty string.
        """
        ids = self._to_list(encoded_ids)
        types = self._to_list(token_type_ids)
        encoded_ids_new = []
        token_type_ids_new = []
        pos = 0
        while pos < len(ids):
            if types[pos] == 0:
                encoded_ids_new.append(ids[pos])
                token_type_ids_new.append(types[pos])
                pos += 1
                continue
            triple = self._take_triple(ids, types, pos)
            char = self.base_tokenizer.decode(triple)[0]
            encoded_ids_new.extend(self._syllable_to_jamo_ids(char))
            token_type_ids_new.extend([1, 2, 3])
            pos += 3
        return encoded_ids_new, token_type_ids_new

    # ------------------------------------------------------------------
    # Batch API
    # ------------------------------------------------------------------

    def batch_encode_char(self, sentences):
        """Encode *sentences* with :meth:`encode_char` and pad to a batch.

        Returns:
            ``(input_ids, attention_mask, token_type_ids)`` as LongTensors.
            Rows are right-padded with ``eos_token_id`` (token type 0) and the
            mask is 1 exactly on the non-padding positions.
        """
        encoded = [self.encode_char(sentence) for sentence in sentences]
        return self._pad_batch(
            [ids for ids, _ in encoded],
            [types for _, types in encoded],
        )

    def batch_encode_jamo_from_char_encoded(self, batch_encoded_ids, batch_token_type_ids, attention_mask=None):
        """Convert a padded "char" batch into the equivalent "jamo" batch.

        Args:
            batch_encoded_ids: rows of char-encoded ids (lists or a 2-D tensor).
            batch_token_type_ids: matching rows of token types.
            attention_mask: optional mask to reuse for the result. The
                conversion preserves sequence length, so the mask returned by
                :meth:`batch_encode_char` applies unchanged. When omitted, the
                mask is derived as ``id != eos_token_id``, which cannot tell
                padding apart from a genuine EOS token.

        Returns:
            ``(input_ids, attention_mask, token_type_ids)`` as LongTensors.
        """
        eos_id = self.base_tokenizer.eos_token_id
        input_ids = []
        derived_mask = []
        token_type_ids_new = []
        for row_ids, row_types in zip(batch_encoded_ids, batch_token_type_ids):
            new_ids, new_types = self.encode_jamo_from_char_encoded(row_ids, row_types)
            input_ids.append(new_ids)
            token_type_ids_new.append(new_types)
            derived_mask.append([0 if new_id == eos_id else 1 for new_id in new_ids])
        if attention_mask is None:
            mask = torch.LongTensor(derived_mask)
        else:
            # Clone so the returned tensor never aliases the caller's mask.
            mask = torch.as_tensor(attention_mask, dtype=torch.long).clone()
        return (
            torch.LongTensor(input_ids),
            mask,
            torch.LongTensor(token_type_ids_new),
        )

    def batch_encode_jamo(self, sentences):
        """Encode *sentences* with :meth:`encode_jamo` and pad to a batch.

        Returns:
            ``(input_ids, attention_mask, token_type_ids)`` as LongTensors.
            Rows are right-padded with ``eos_token_id`` (token type 0) and the
            mask is 1 exactly on the non-padding positions.
        """
        encoded = [self.encode_jamo(sentence) for sentence in sentences]
        return self._pad_batch(
            [ids for ids, _ in encoded],
            [types for _, types in encoded],
        )