import re
from collections import Counter

import torch
from transformers import PretrainedConfig, PreTrainedModel


class SentenceTokenizerConfig(PretrainedConfig):
    model_type = "sentence_tokenizer"

    def __init__(
        self,
        min_length=32,
        max_length=64,
        n_overlap=3,
        roll=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.min_length = min_length
        self.max_length = max_length
        self.n_overlap = n_overlap
        self.roll = roll

class SentenceTokenizer(PreTrainedModel):
    config_class = SentenceTokenizerConfig

    def __init__(self, config):
        super().__init__(config)
        # Placeholder parameter so the wrapper has at least one weight to
        # save and load; it is never used in any computation.
        self.temp_module = torch.nn.Parameter(torch.ones(1))
        self.min_length = config.min_length
        self.max_length = config.max_length
        self.n_overlap = config.n_overlap
        self.roll = config.roll
    def split_text_into_sentences(self, text):
        # Split wherever a non-Hangul character is followed by a space
        # (e.g. ". ", "! ", "? ") and re-attach the captured delimiter to
        # the preceding piece, so ''.join(result) reproduces the input.
        split_text = re.split(r'([^가-힣] )', text)
        split_text = [
            split_text[i] + split_text[i + 1]
            for i in range(0, len(split_text) - 1, 2)
        ] + ([split_text[-1]] if len(split_text) % 2 != 0 else [])
        return split_text
    def merge_chunks(self, chunks):
        # Greedily concatenate chunks left to right until the buffer exceeds
        # min_length characters, then start a new chunk.
        merged_chunks = []
        buffer = ""
        for chunk in chunks:
            buffer += chunk
            if len(buffer) > self.min_length:
                merged_chunks.append(buffer)
                buffer = ""
        if buffer:
            merged_chunks.append(buffer)
        return merged_chunks
    def merge_chunks_reverse(self, chunks):
        # Same greedy merge, but applied right to left (on reversed chunks and
        # reversed characters) so that a short trailing remainder is absorbed
        # into the preceding chunk instead of being left dangling.
        chunks_reverse = [chunk[::-1] for chunk in chunks[::-1]]

        merged_chunks = []
        buffer = ""
        for chunk in chunks_reverse:
            buffer += chunk
            if len(buffer) > self.min_length:
                merged_chunks.append(buffer)
                buffer = ""
        if buffer:
            merged_chunks.append(buffer)

        # Undo both reversals to restore the original order.
        return [chunk[::-1] for chunk in merged_chunks[::-1]]
    def split_text(self, text):
        # Break an over-long sentence into word-aligned chunks of at most
        # max_length characters.
        words = self.split_space(text)

        splitted_chunks = []
        buffer = []
        for word in words:
            buffer.append(word)
            merged_text = ''.join(buffer)
            if len(merged_text) > self.max_length:
                # The last word pushed the buffer over the limit: flush the
                # buffer without it and start a new buffer from that word.
                buffer.pop()
                splitted_chunks.append(''.join(buffer))
                buffer = [word]
        if buffer:
            splitted_chunks.append(''.join(buffer))
        return splitted_chunks
    def tokenize(self, text):
        # Split into sentences, break sentences longer than max_length into
        # word-aligned pieces, then merge small pieces up to min_length from
        # both directions.
        splitted_chunks = []
        sentences = self.split_text_into_sentences(text)
        for chunk in sentences:
            if len(chunk) >= self.max_length:
                splitted_chunks.extend(self.split_text(chunk))
            else:
                splitted_chunks.append(chunk)
        merged_chunks = self.merge_chunks(splitted_chunks)
        merged_chunks = self.merge_chunks_reverse(merged_chunks)
        return merged_chunks
    def split_space(self, text):
        # Split on whitespace while keeping each whitespace run attached to
        # the preceding word, so ''.join(result) reproduces the input.
        split_text = re.split(r'(\s+)', text)
        filtered_text = [s + sp for s, sp in zip(split_text[::2], split_text[1::2] + [''])]
        return filtered_text
    def overlap(self, chunks):
        # Produce (start, end, text) spans that slide across each adjacent
        # chunk pair in n_overlap word-block steps, so neighbouring chunks
        # share overlapping context.
        if not chunks:
            return []
        if self.roll:
            # Wrap around: pad with the last chunk in front and the first
            # chunk at the end.
            chunks = [chunks[-1]] + chunks + [chunks[0]]
        res = []
        total_idx = 0
        for chunk_idx in range(len(chunks) - 1):
            chunk_a, chunk_b = chunks[chunk_idx], chunks[chunk_idx + 1]
            chunk_a_words = self.split_space(chunk_a)
            chunk_b_words = self.split_space(chunk_b)
            chunk_a_overlap_length = len(chunk_a_words) // self.n_overlap
            chunk_b_overlap_length = len(chunk_b_words) // self.n_overlap
            for overlap_idx in range(self.n_overlap):
                # Drop overlap_idx word blocks from the front of chunk_a and
                # take the same number of blocks from the front of chunk_b.
                chunk_a_past = ''.join(chunk_a_words[:chunk_a_overlap_length * overlap_idx])
                chunk_a_overlap = ''.join(chunk_a_words[chunk_a_overlap_length * overlap_idx:])
                chunk_b_overlap = ''.join(chunk_b_words[:chunk_b_overlap_length * overlap_idx])
                overlap = chunk_a_overlap + chunk_b_overlap
                start = total_idx + len(chunk_a_past)
                end = start + len(overlap)
                res.append((start, end, overlap))
            total_idx += len(chunk_a)
        res.append((total_idx, total_idx + len(chunks[-1]), chunks[-1]))
        return res
    def decode_overlap(self, chunks):
        # Reconstruct the text from overlapping (start, end, text) spans by
        # majority vote over the characters proposed for each position.
        if not chunks:
            return ""

        max_length = max(end for _, end, _ in chunks)

        # Collect every character proposed for each character position.
        index_char_map = {i: [] for i in range(max_length)}
        for start, end, chunk in chunks:
            for i, char in enumerate(chunk):
                index = start + i
                if index < max_length:
                    index_char_map[index].append(char)

        # Keep the most frequent character at each position.
        reconstructed_text = []
        for i in range(max_length):
            most_common_char, _ = Counter(index_char_map[i]).most_common(1)[0]
            reconstructed_text.append(most_common_char)
        res = "".join(reconstructed_text)
        if self.roll:
            # Strip the wrap-around padding added in overlap().
            res = res[len(chunks[0][2]):-len(chunks[-1][2])]
        return res
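

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the sample text
    # and parameter values below are illustrative assumptions, chosen only to
    # show how tokenize(), overlap() and decode_overlap() fit together.
    config = SentenceTokenizerConfig(min_length=16, max_length=32, n_overlap=3, roll=False)
    tokenizer = SentenceTokenizer(config)

    text = "이것은 첫 번째 문장입니다. 이것은 두 번째 문장입니다. 마지막 문장입니다."
    chunks = tokenizer.tokenize(text)            # merged, length-bounded chunks
    spans = tokenizer.overlap(chunks)            # list of (start, end, text) spans
    restored = tokenizer.decode_overlap(spans)   # majority-vote reconstruction

    print(chunks)
    print(restored == text)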