import re
from collections import Counter

import torch
from transformers import PretrainedConfig, PreTrainedModel


class SentenceTokenizerConfig(PretrainedConfig):
    model_type = "sentence_tokenizer"

    def __init__(
        self,
        min_length=32,
        max_length=64,
        n_overlap=3,
        roll=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.min_length = min_length
        self.max_length = max_length
        self.n_overlap = n_overlap
        self.roll = roll


class SentenceTokenizer(PreTrainedModel):
    config_class = SentenceTokenizerConfig

    def __init__(self, config):
        super().__init__(config)
        # Dummy parameter so this otherwise weight-free model still exposes a
        # device/dtype through the PreTrainedModel machinery
        self.temp_module = torch.nn.Parameter(torch.ones(1))
        self.min_length = config.min_length
        self.max_length = config.max_length
        self.n_overlap = config.n_overlap
        self.roll = config.roll

    def split_text_into_sentences(self, text):
        # Split at a non-Hangul character followed by a space (e.g. ". "),
        # keeping the captured delimiter attached to the preceding sentence
        split_text = re.split(r'([^가-힣] )', text)
        split_text = [
            split_text[i] + split_text[i + 1]
            for i in range(0, len(split_text) - 1, 2)
        ] + ([split_text[-1]] if len(split_text) % 2 != 0 else [])
        return split_text

    def merge_chunks(self, chunks):
        merged_chunks = []
        buffer = ""
        for chunk in chunks:
            buffer += chunk
            # Once the buffer exceeds the minimum length, finalize it
            if len(buffer) > self.min_length:
                merged_chunks.append(buffer)
                buffer = ""
        # Add any remaining buffer as the last chunk
        if buffer:
            merged_chunks.append(buffer)
        return merged_chunks

    def merge_chunks_reverse(self, chunks):
        # Reverse both the chunk order and the characters so the same greedy
        # merge runs right-to-left; this keeps the final chunk from ending up
        # shorter than min_length
        chunks_reverse = [chunk[::-1] for chunk in chunks[::-1]]
        merged_chunks = []
        buffer = ""
        for chunk in chunks_reverse:
            buffer += chunk
            # Once the buffer exceeds the minimum length, finalize it
            if len(buffer) > self.min_length:
                merged_chunks.append(buffer)
                buffer = ""
        # Add any remaining buffer as the last chunk
        if buffer:
            merged_chunks.append(buffer)
        # Undo both reversals to restore the original orientation
        return [chunk[::-1] for chunk in merged_chunks[::-1]]

    def split_text(self, text):
        # Step 1: Split the text into words, each keeping its trailing spaces
        words = self.split_space(text)
        # Step 2: Greedily merge words while the merged text stays within max_length
        splitted_chunks = []
        buffer = []
        for word in words:
            buffer.append(word)  # Add the word to the buffer
            merged_text = ''.join(buffer)
            # If the merged text exceeds max_length, push the current buffer to the result
            if len(merged_text) > self.max_length:
                # Remove the last added word and save the current buffer as a chunk
                buffer.pop()
                splitted_chunks.append(''.join(buffer))
                buffer = [word]  # Start a new buffer with the current word
        # Step 3: Append the leftover buffer
        if buffer:
            splitted_chunks.append(''.join(buffer))
        return splitted_chunks

    def tokenize(self, text):
        splitted_chunks = []
        # Step 1: Split text into sentences
        sentences = self.split_text_into_sentences(text)
        # Step 2: Break any sentence that is already at or above max_length
        for chunk in sentences:
            if len(chunk) >= self.max_length:
                splitted_chunks.extend(self.split_text(chunk))
            else:
                splitted_chunks.append(chunk)
        # Step 3: Merge forward, then backward, so no chunk is left below min_length
        merged_chunks = self.merge_chunks(splitted_chunks)
        merged_chunks = self.merge_chunks_reverse(merged_chunks)
        return merged_chunks

    def split_space(self, text):
        # Keep whitespace as part of the split parts: each word carries its
        # trailing spaces, so joining the parts reproduces the input exactly
        split_text = re.split(r'(\s+)', text)
        filtered_text = [s + sp for s, sp in zip(split_text[::2], split_text[1::2] + [''])]
        return filtered_text

    def overlap(self, chunks):
        if not chunks:
            return []
        if self.roll:
            # Wrap around: pad with the last chunk in front and the first at the end
            chunks = [chunks[-1]] + chunks + [chunks[0]]
        res = []
        total_idx = 0
        for chunk_idx in range(len(chunks) - 1):
            chunk_a, chunk_b = chunks[chunk_idx], chunks[chunk_idx + 1]
            chunk_a_words = self.split_space(chunk_a)
            chunk_b_words = self.split_space(chunk_b)
            chunk_a_overlap_length = len(chunk_a_words) // self.n_overlap
            chunk_b_overlap_length = len(chunk_b_words) // self.n_overlap
            # Slide a window rightward in n_overlap steps: drop overlap_idx
            # word-blocks from the front of chunk_a and take the same number
            # from the front of chunk_b
            for overlap_idx in range(self.n_overlap):
                chunk_a_past = ''.join(chunk_a_words[:chunk_a_overlap_length * overlap_idx])
                chunk_a_overlap = ''.join(chunk_a_words[chunk_a_overlap_length * overlap_idx:])
                chunk_b_overlap = ''.join(chunk_b_words[:chunk_b_overlap_length * overlap_idx])
                overlap = chunk_a_overlap + chunk_b_overlap
                start = total_idx + len(chunk_a_past)
                end = start + len(overlap)
                res.append((start, end, overlap))
            total_idx += len(chunk_a)
        res.append((total_idx, total_idx + len(chunks[-1]), chunks[-1]))
        return res

    def decode_overlap(self, chunks):
        if not chunks:
            return ""
        # Determine total length based on the largest end index
        max_length = max(end for _, end, _ in chunks)
        # Dictionary to store candidate characters at each index
        index_char_map = {i: [] for i in range(max_length)}
        # Populate index_char_map with characters from chunks
        for start, end, chunk in chunks:
            for i, char in enumerate(chunk):
                index = start + i
                if index < max_length:
                    index_char_map[index].append(char)
        # Reconstruct text using a per-character majority vote
        reconstructed_text = []
        for i in range(max_length):
            if not index_char_map[i]:
                continue  # No chunk covers this index; skip instead of raising IndexError
            most_common_char, _ = Counter(index_char_map[i]).most_common(1)[0]
            reconstructed_text.append(most_common_char)
        res = "".join(reconstructed_text)
        if self.roll:
            # Strip the wrapped-in last/first chunks added by overlap()
            res = res[len(chunks[0][2]):-len(chunks[-1][2])]
        return res
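

# A minimal usage sketch, not part of the module above: the parameter values
# and the sample text are illustrative assumptions. It round-trips a string
# through tokenize(), overlap(), and decode_overlap(); with the default
# roll=False the majority vote recovers the input exactly, since every overlap
# window copies its characters from their original positions.
if __name__ == "__main__":
    config = SentenceTokenizerConfig(min_length=8, max_length=24, n_overlap=2)
    tokenizer = SentenceTokenizer(config)

    text = "안녕하세요. 이것은 예시 문장입니다. This is a short example."

    chunks = tokenizer.tokenize(text)        # length-bounded chunks
    windows = tokenizer.overlap(chunks)      # (start, end, text) windows
    restored = tokenizer.decode_overlap(windows)

    print(chunks)
    assert restored == text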