"""Rule-based noise augmentation for Korean (Hangul) text.

Defines a Hugging Face style config/model pair whose "model" applies random,
character-level perturbations to Hangul syllables: moving or copying
consonants across syllable boundaries, swapping jamo for entries in
hand-written "similar" tables, and adding final consonants.
"""
import random

import torch
import torch.nn as nn
from jamotools import split_syllables, join_jamos
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig

# Bounds of the precomposed Hangul syllable range that the augmentations
# operate on; every other character is passed through untouched.
_HANGUL_SYLLABLE_FIRST = 0xAC00
_HANGUL_SYLLABLE_LAST = 0xD7A3


def _is_hangul_syllable(char):
    """Return True if *char* is exactly one precomposed Hangul syllable."""
    return (
        len(char) == 1
        and _HANGUL_SYLLABLE_FIRST <= ord(char) <= _HANGUL_SYLLABLE_LAST
    )


class HangulAugmentatorConfig(PretrainedConfig):
    """Configuration for :class:`HangulAugmentator`.

    Args:
        p: Probability threshold used for every augmentation decision; a step
            is applied at a position when ``p >= random.random()``.
        do_link_next: Enable the ``link_next`` pass.
        do_link_before: Enable the ``link_before`` pass.
        do_replace_similar: Enable the per-character ``replace_similar`` pass.
        do_add_jong: Enable the per-character ``add_jong`` pass.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "hangul_augmentator"

    def __init__(
        self,
        p=0.5,
        do_link_next=True,
        do_link_before=True,
        do_replace_similar=True,
        do_add_jong=True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.p = p
        self.do_link_next = do_link_next
        self.do_link_before = do_link_before
        self.do_replace_similar = do_replace_similar
        self.do_add_jong = do_add_jong


class HangulAugmentator(PreTrainedModel):
    """Randomly perturb Hangul syllables in a sentence.

    Calling an instance with a string runs the passes enabled in the config,
    in this order: ``link_next``, ``link_before``, ``replace_similar``,
    ``add_jong``. All randomness comes from the global ``random`` module, so
    results are reproducible with ``random.seed``.
    """

    config_class = HangulAugmentatorConfig

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): never read by the augmentation code; presumably a
        # placeholder so the module owns at least one parameter - confirm.
        self.temp_module = torch.nn.Parameter(torch.ones(1))

        # Jamo -> candidate replacements for the initial consonant.
        # Entry order and duplicates are significant: they weight
        # ``random.choice`` in ``replace_similar``.
        self.ja_similar_dict = {
            'ㅂ': ['ㅃ', 'ㅍ'],
            'ㄱ': ['ㄲ', 'ㅋ'],
            'ㄷ': ['ㄸ', 'ㅌ'],
            'ㄲ': ['ㄲ', 'ㅋ'],
            'ㅅ': ['ㅆ'],
            'ㅈ': ['ㅉ', 'ㅊ'],
            'ㅌ': ['ㄸ', 'ㅌ', 'ㄷ'],
            'ㅋ': ['ㄲ', 'ㄱ'],
            'ㅍ': ['ㅃ', 'ㅂ'],
            'ㅃ': ['ㅍ', 'ㅂ'],
            'ㄸ': ['ㅌ', 'ㄷ'],
            'ㅊ': ['ㅉ', 'ㅉ', 'ㅈ'],
            'ㅆ': ['ㅅ'],
            'ㅉ': ['ㅉ', 'ㅈ'],
        }

        # Jamo -> candidate replacements for the vowel.
        self.mo_similar_dict = {
            'ㅕ': ['ㅓ'],
            'ㅏ': ['ㅑ'],
            'ㅐ': ['ㅔ', 'ㅒ'],
            'ㅗ': ['ㅛ'],
            'ㅙ': ['ㅚ', 'ㅞ'],
            'ㅡ': ['ㅜ'],
            'ㅣ': ['ㅟ'],
            'ㅜ': ['ㅠ'],
            'ㅓ': ['ㅕ'],
            'ㅔ': ['ㅖ', 'ㅞ'],
            'ㅛ': ['ㅗ'],
            'ㅚ': ['ㅙ', 'ㅞ'],
            'ㅠ': ['ㅜ'],
            'ㅝ': ['ㅓ'],
            'ㅖ': ['ㅒ'],
            'ㅢ': ['ㅟ'],
            'ㅑ': ['ㅏ'],
            'ㅞ': ['ㅙ', 'ㅚ', 'ㅔ'],
            'ㅒ': ['ㅖ'],
        }

        # Final consonant -> [final consonant to keep, initial consonant to
        # give the following syllable]; consumed by ``_link_next``. The key
        # order also defines the candidate order for ``add_jong``.
        self.jong_link_dict = {
            'ㅂ': ['ㅂ', 'ㅂ'],
            'ㅍ': ['ㅍ', 'ㅍ'],
            'ㄱ': ['ㄱ', 'ㄱ'],
            'ㅊ': ['ㅊ', 'ㅊ'],
            'ㅎ': ['ㅎ', 'ㅎ'],
            'ㅇ': ['ㅇ', 'ㅇ'],
            'ㅌ': ['ㅌ', 'ㅌ'],
            'ㄽ': ['ㄹ', 'ㅅ'],
            'ㄿ': ['ㄹ', 'ㅍ'],
            'ㄵ': ['ㄴ', 'ㅈ'],
            'ㄲ': ['ㄲ', 'ㄲ'],
            'ㅋ': ['ㅋ', 'ㅋ'],
            'ㄴ': ['ㄴ', 'ㄴ'],
            'ㄷ': ['ㄷ', 'ㄷ'],
            'ㅀ': ['ㄹ', 'ㅎ'],
            'ㅈ': ['ㅈ', 'ㅈ'],
            'ㄺ': ['ㄹ', 'ㄱ'],
            'ㄼ': ['ㄹ', 'ㅂ'],
            'ㅅ': ['ㅅ', 'ㅅ'],
            'ㄶ': ['ㄴ', 'ㅎ'],
            'ㄹ': ['ㄹ', 'ㄹ'],
            'ㅁ': ['ㅁ', 'ㅁ'],
            'ㄳ': ['ㄱ', 'ㅅ'],
            'ㅆ': ['ㅆ', 'ㅆ'],
            'ㄾ': ['ㄹ', 'ㅌ'],
            'ㅄ': ['ㅂ', 'ㅅ'],
            'ㄻ': ['ㄹ', 'ㅁ'],
        }

        # Final consonant -> candidate replacement final consonants.
        self.jong_similar_dict = {
            'ㄹ': [
                'ㄱ', 'ㄲ', 'ㄳ', 'ㄴ', 'ㄵ', 'ㄶ', 'ㄷ', 'ㄺ', 'ㄻ', 'ㄼ', 'ㄽ', 'ㄾ',
                'ㄿ', 'ㅀ', 'ㅁ', 'ㅂ', 'ㅅ', 'ㅆ', 'ㅇ', 'ㅈ', 'ㅊ', 'ㅌ', 'ㅍ', 'ㅎ'
            ],
            'ㄲ': [
                'ㄱ', 'ㄳ', 'ㄴ', 'ㄹ', 'ㄻ', 'ㅁ', 'ㅂ', 'ㅅ', 'ㅆ', 'ㅇ', 'ㅋ', 'ㅌ',
                'ㅎ'
            ],
            'ㅅ': [
                'ㄱ', 'ㄲ', 'ㄴ', 'ㄶ', 'ㄷ', 'ㄹ', 'ㄺ', 'ㄻ', 'ㄽ', 'ㅀ', 'ㅁ', 'ㅂ',
                'ㅆ', 'ㅇ', 'ㅈ', 'ㅊ', 'ㅌ', 'ㅍ', 'ㅎ'
            ],
            'ㅁ': [
                'ㄱ', 'ㄲ', 'ㄳ', 'ㄴ', 'ㄶ', 'ㄷ', 'ㄹ', 'ㄺ', 'ㄻ', 'ㄼ', 'ㄾ', 'ㄿ',
                'ㅀ', 'ㅂ', 'ㅄ', 'ㅅ', 'ㅆ', 'ㅇ', 'ㅈ', 'ㅊ', 'ㅋ', 'ㅌ', 'ㅍ', 'ㅎ'
            ],
            'ㄴ': [
                'ㄱ', 'ㄲ', 'ㄳ', 'ㄵ', 'ㄶ', 'ㄷ', 'ㄹ', 'ㄺ', 'ㄻ', 'ㄼ', 'ㄾ', 'ㅀ',
                'ㅁ', 'ㅂ', 'ㅄ', 'ㅅ', 'ㅆ', 'ㅇ', 'ㅈ', 'ㅊ', 'ㅋ', 'ㅌ', 'ㅍ', 'ㅎ'
            ],
            'ㅇ': ['ㅎ'],
            'ㅆ': [
                'ㄱ', 'ㄲ', 'ㄳ', 'ㄴ', 'ㄵ', 'ㄶ', 'ㄷ', 'ㄹ', 'ㄺ', 'ㄻ', 'ㄼ', 'ㅀ',
                'ㅁ', 'ㅂ', 'ㅄ', 'ㅅ', 'ㅇ', 'ㅈ', 'ㅊ', 'ㅋ', 'ㅌ', 'ㅍ', 'ㅎ'
            ],
            'ㅄ': ['ㅂ', 'ㅍ'],
            'ㅂ': [
                'ㄱ', 'ㄴ', 'ㄵ', 'ㄷ', 'ㄹ', 'ㄺ', 'ㄻ', 'ㄼ', 'ㅀ', 'ㅁ', 'ㅄ', 'ㅅ',
                'ㅆ', 'ㅇ', 'ㅈ', 'ㅊ', 'ㅌ', 'ㅍ', 'ㅎ'
            ],
            'ㄶ': ['ㄴ', 'ㄵ'],
            'ㅍ': [
                'ㄱ', 'ㄲ', 'ㄴ', 'ㄵ', 'ㄶ', 'ㄷ', 'ㄹ', 'ㄺ', 'ㄻ', 'ㄼ', 'ㅀ', 'ㅁ',
                'ㅂ', 'ㅄ', 'ㅅ', 'ㅆ', 'ㅇ', 'ㅌ', 'ㅎ'
            ],
            'ㄱ': [
                'ㄲ', 'ㄳ', 'ㄴ', 'ㄵ', 'ㄶ', 'ㄷ', 'ㄹ', 'ㄺ', 'ㄻ', 'ㄼ', 'ㅀ', 'ㅁ',
                'ㅂ', 'ㅄ', 'ㅅ', 'ㅆ', 'ㅇ', 'ㅈ', 'ㅊ', 'ㅋ', 'ㅌ', 'ㅍ', 'ㅎ'
            ],
            'ㅌ': [
                'ㄱ', 'ㄲ', 'ㄴ', 'ㄷ', 'ㄹ', 'ㄺ', 'ㄻ', 'ㅁ', 'ㅂ', 'ㅄ', 'ㅅ', 'ㅆ',
                'ㅇ', 'ㅈ', 'ㅊ', 'ㅍ', 'ㅎ'
            ],
            'ㄼ': ['ㄹ', 'ㄺ', 'ㄻ', 'ㅀ'],
            'ㄷ': [
                'ㄱ', 'ㄲ', 'ㄴ', 'ㄹ', 'ㄺ', 'ㄻ', 'ㄼ', 'ㅀ', 'ㅁ', 'ㅂ', 'ㅅ', 'ㅆ',
                'ㅇ', 'ㅈ', 'ㅊ', 'ㅌ'
            ],
            'ㅈ': [
                'ㄱ', 'ㄴ', 'ㄶ', 'ㄷ', 'ㄹ', 'ㄺ', 'ㄻ', 'ㅁ', 'ㅂ', 'ㅅ', 'ㅆ', 'ㅇ',
                'ㅊ', 'ㅌ', 'ㅎ'
            ],
            'ㅎ': ['ㄱ', 'ㄳ', 'ㄴ', 'ㄹ', 'ㄻ', 'ㅁ', 'ㅂ', 'ㅅ', 'ㅆ', 'ㅇ', 'ㅈ', 'ㅊ', 'ㅍ'],
            'ㅊ': ['ㄷ', 'ㅁ', 'ㅅ', 'ㅆ', 'ㅇ', 'ㅈ', 'ㅌ'],
            'ㄵ': ['ㄴ', 'ㄶ'],
            'ㄻ': ['ㄹ', 'ㄺ', 'ㄽ', 'ㅀ'],
            'ㅀ': ['ㄹ', 'ㄺ', 'ㄻ'],
            'ㄺ': ['ㄹ', 'ㄻ', 'ㄼ', 'ㅀ'],
            'ㅋ': ['ㄱ', 'ㄳ', 'ㅄ'],
            'ㄳ': ['ㄱ'],
            'ㄾ': ['ㄹ'],
        }

    def __call__(self, sentence):
        """Apply every enabled augmentation pass to *sentence*.

        Args:
            sentence: Input text.

        Returns:
            The augmented text as a new string.
        """
        # NOTE(review): this replaces ``nn.Module.__call__`` rather than
        # defining ``forward``; kept as-is so existing callers see no change.
        p = self.config.p
        if self.config.do_link_next:
            sentence = self.link_next(sentence, p)
        if self.config.do_link_before:
            sentence = self.link_before(sentence, p)
        if self.config.do_replace_similar:
            # The coin flip is evaluated before the replacement for each
            # character, so the sequence of ``random`` calls is stable.
            sentence = ''.join(
                self.replace_similar(char) if p >= random.random() else char
                for char in sentence
            )
        if self.config.do_add_jong:
            sentence = ''.join(
                self.add_jong(char) if p >= random.random() else char
                for char in sentence
            )
        return sentence

    def _link_next(self, char1, char2):
        """Carry the final consonant of *char1* into *char2*.

        Only applies when both are Hangul syllables, *char1* has a final
        consonant and *char2* starts with 'ㅇ'. The ``jong_link_dict`` entry
        for *char1*'s final consonant supplies the final consonant *char1*
        keeps and the initial consonant that replaces 'ㅇ' in *char2*.

        Returns:
            ``(new_char1, new_char2)``, or the inputs unchanged when the
            rule does not apply.
        """
        if not (_is_hangul_syllable(char1) and _is_hangul_syllable(char2)):
            return char1, char2
        char1_jamo = list(split_syllables(char1))
        if len(char1_jamo) != 3:
            return char1, char2
        char2_jamo = list(split_syllables(char2))
        if char2_jamo[0] != 'ㅇ':
            return char1, char2
        new_jong, new_cho = self.jong_link_dict[char1_jamo[-1]]
        # ``[:1]`` keeps only the first character that join_jamos returns.
        new_char1 = join_jamos(char1_jamo[:2] + [new_jong])[:1]
        new_char2 = join_jamos([new_cho] + char2_jamo[1:])[:1]
        return new_char1, new_char2

    def link_next(self, sentence, p):
        """Run ``_link_next`` over adjacent character pairs of *sentence*.

        Each pair is processed when ``p >= random.random()``. Pairs are
        visited left to right and see the results of earlier pairs.

        Returns:
            The augmented sentence.
        """
        chars = list(sentence)
        for i in range(len(chars) - 1):
            if p >= random.random():
                chars[i], chars[i + 1] = self._link_next(chars[i], chars[i + 1])
        return ''.join(chars)

    def _link_before(self, char1, char2):
        """Copy the initial consonant of *char2* into *char1* as its final.

        Only applies when both are Hangul syllables and *char1* splits into
        exactly two jamo (no final consonant). *char2* is never modified.

        Returns:
            ``(new_char1, char2)``, or the inputs unchanged when the rule
            does not apply.
        """
        if not (_is_hangul_syllable(char1) and _is_hangul_syllable(char2)):
            return char1, char2
        char1_jamo = list(split_syllables(char1))
        if len(char1_jamo) != 2:
            return char1, char2
        char2_jamo = list(split_syllables(char2))
        # ``[:1]`` keeps only the first character that join_jamos returns,
        # dropping any jamo that did not compose into that syllable.
        new_char1 = join_jamos(char1_jamo[:2] + char2_jamo[:1])[:1]
        return new_char1, char2

    def link_before(self, sentence, p):
        """Run ``_link_before`` over adjacent character pairs of *sentence*.

        Each pair is processed when ``p >= random.random()``. Pairs are
        visited left to right and see the results of earlier pairs.

        Returns:
            The augmented sentence.
        """
        chars = list(sentence)
        for i in range(len(chars) - 1):
            if p >= random.random():
                chars[i], chars[i + 1] = self._link_before(chars[i], chars[i + 1])
        return ''.join(chars)

    def replace_similar(self, char):
        """Replace each jamo of *char* with a random entry from its table.

        The initial consonant, vowel and (if present) final consonant are
        looked up in ``ja_similar_dict``, ``mo_similar_dict`` and
        ``jong_similar_dict`` respectively. A jamo with no table entry is
        kept: the lookup default is the jamo itself, and ``random.choice``
        over a one-character string returns that character.

        Returns:
            The perturbed syllable, or *char* unchanged if it is not a
            single Hangul syllable.
        """
        if not _is_hangul_syllable(char):
            return char
        jamo = list(split_syllables(char))
        jamo[0] = random.choice(self.ja_similar_dict.get(jamo[0], jamo[0]))
        jamo[1] = random.choice(self.mo_similar_dict.get(jamo[1], jamo[1]))
        if len(jamo) == 3:
            jamo[2] = random.choice(self.jong_similar_dict.get(jamo[2], jamo[2]))
        return join_jamos(jamo)[:1]

    def add_jong(self, char):
        """Append a random final consonant to a syllable that has none.

        The consonant is drawn uniformly from the keys of
        ``jong_link_dict``.

        Returns:
            The new syllable, or *char* unchanged if it is not a single
            Hangul syllable or already splits into three jamo.
        """
        if not _is_hangul_syllable(char):
            return char
        jamo = list(split_syllables(char))
        if len(jamo) == 3:
            return char
        new_jong = random.choice(list(self.jong_link_dict.keys()))
        return join_jamos(jamo[:2] + [new_jong])[:1]