import torch
import re
from collections import Counter
from transformers import PretrainedConfig, PreTrainedModel


class SentenceTokenizerConfig(PretrainedConfig):
    model_type = "sentence_tokenizer"

    def __init__(
        self,
        min_length=32,
        max_length=64,
        n_overlap=3,
        roll=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.min_length = min_length
        self.max_length = max_length
        self.n_overlap = n_overlap
        self.roll = roll


class SentenceTokenizer(PreTrainedModel):
    config_class = SentenceTokenizerConfig

    def __init__(self, config):
        super().__init__(config)
        # Dummy parameter: the tokenizer has no learnable weights of its own, so this
        # presumably just gives PreTrainedModel a tensor to register and serialize.
        self.temp_module = torch.nn.Parameter(torch.ones(1))
        self.min_length = config.min_length
        self.max_length = config.max_length
        self.n_overlap = config.n_overlap
        self.roll = config.roll

    def split_text_into_sentences(self, text):
        # Split after any non-Hangul character that is followed by a space
        # (typically sentence-final punctuation), keeping the delimiter
        # attached to the preceding sentence.
        split_text = re.split(r'([^가-힣] )', text)
        split_text = [
            split_text[i] + split_text[i + 1]
            for i in range(0, len(split_text) - 1, 2)
        ] + ([split_text[-1]] if len(split_text) % 2 != 0 else [])
        return split_text
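    # Illustrative example (hypothetical input, not from the repository):
    #   self.split_text_into_sentences("안녕하세요. 반갑습니다. 끝")
    #   -> ["안녕하세요. ", "반갑습니다. ", "끝"]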

    def merge_chunks(self, chunks):
        merged_chunks = []
        buffer = ""
        for chunk in chunks:
            buffer += chunk
            if len(buffer) > self.min_length:  # Once the buffer exceeds min_length, finalize it
                merged_chunks.append(buffer)
                buffer = ""
        # Add any remaining buffer as the last chunk
        if buffer:
            merged_chunks.append(buffer)
        return merged_chunks
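    # Illustrative example (assuming min_length=5 rather than the default):
    #   self.merge_chunks(["ab", "cde", "fg", "h"]) -> ["abcdefg", "h"]
    # Note the short tail; merge_chunks_reverse below takes care of it.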

    def merge_chunks_reverse(self, chunks):
        # Same greedy merge, but applied from the end of the text: both the chunk
        # order and the characters are reversed, merged, then reversed back, so a
        # short trailing chunk gets absorbed into its predecessor.
        chunks_reverse = []
        for chunk in chunks[::-1]:
            chunks_reverse.append(chunk[::-1])
        merged_chunks = []
        buffer = ""
        for chunk in chunks_reverse:
            buffer += chunk
            if len(buffer) > self.min_length:  # Once the buffer exceeds min_length, finalize it
                merged_chunks.append(buffer)
                buffer = ""
        # Add any remaining buffer as the last chunk
        if buffer:
            merged_chunks.append(buffer)
        # Undo the reversal
        res_merged_chunks = []
        for chunk in merged_chunks[::-1]:
            res_merged_chunks.append(chunk[::-1])
        return res_merged_chunks
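    # Illustrative example (assuming min_length=5): the short tail left by
    # merge_chunks is merged into the preceding chunk:
    #   self.merge_chunks_reverse(["abcdefg", "h"]) -> ["abcdefgh"]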

    def split_text(self, text):
        # Step 1: Split the text into space-delimited words (spaces kept)
        words = self.split_space(text)
        # Step 2: Greedily merge words while the merged text stays within max_length
        splitted_chunks = []
        buffer = []
        for word in words:
            buffer.append(word)  # Add the word to the buffer
            merged_text = ''.join(buffer)
            # If the merged text exceeds max_length, push the current buffer to the result
            if len(merged_text) > self.max_length:
                # Remove the last added word and save the current buffer as a chunk
                buffer.pop()
                splitted_chunks.append(''.join(buffer))
                buffer = [word]  # Start a new buffer with the current word
        # Step 3: Append the leftover buffer
        if buffer:
            splitted_chunks.append(''.join(buffer))
        return splitted_chunks
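    # Illustrative example (assuming max_length=10):
    #   self.split_text("aaa bbb ccc ddd") -> ["aaa bbb ", "ccc ddd"]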

    def tokenize(self, text):
        splitted_chunks = []
        # Step 1: Split text into sentences
        sentences = self.split_text_into_sentences(text)
        for chunk in sentences:
            # Step 2: Break up sentences that are too long, keep the rest as-is
            if len(chunk) >= self.max_length:
                splitted_chunks.extend(self.split_text(chunk))
            else:
                splitted_chunks.append(chunk)
        # Step 3: Merge short pieces forward, then backward, so chunks respect min_length
        merged_chunks = self.merge_chunks(splitted_chunks)
        merged_chunks = self.merge_chunks_reverse(merged_chunks)
        return merged_chunks
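    # Chunking is lossless: concatenating the returned chunks should reproduce
    # the input exactly, i.e. ''.join(self.tokenize(text)) == text.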

    def split_space(self, text):
        split_text = re.split(r'(\s+)', text)  # Keep spaces as part of the split parts
        filtered_text = [s + sp for s, sp in zip(split_text[::2], split_text[1::2] + [''])]
        return filtered_text
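    # Illustrative example: self.split_space("aaa bbb") -> ["aaa ", "bbb"]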

    def overlap(self, chunks):
        if not chunks:
            return []
        if self.roll:
            # Circular padding: wrap the last chunk to the front and the first to the back
            chunks = [chunks[-1]] + chunks + [chunks[0]]
        res = []
        total_idx = 0
        for chunk_idx in range(len(chunks) - 1):
            # For each adjacent pair, emit n_overlap spans that slide word-group by
            # word-group from chunk_a into chunk_b, each recorded as (start, end, text)
            # in character offsets of the concatenated chunks.
            chunk_a, chunk_b = chunks[chunk_idx], chunks[chunk_idx + 1]
            chunk_a_words, chunk_b_words = self.split_space(chunk_a), self.split_space(chunk_b)
            chunk_a_overlap_length = len(chunk_a_words) // self.n_overlap
            chunk_b_overlap_length = len(chunk_b_words) // self.n_overlap
            for overlap_idx in range(self.n_overlap):
                chunk_a_past = ''.join(chunk_a_words[:chunk_a_overlap_length * overlap_idx])
                chunk_a_overlap = ''.join(chunk_a_words[chunk_a_overlap_length * overlap_idx:])
                chunk_b_overlap = ''.join(chunk_b_words[:chunk_b_overlap_length * overlap_idx])
                overlap = chunk_a_overlap + chunk_b_overlap
                start = total_idx + len(chunk_a_past)
                end = start + len(overlap)
                res.append((start, end, overlap))
            total_idx += len(chunk_a)
        # The last chunk is emitted as-is
        res.append((total_idx, total_idx + len(chunks[-1]), chunks[-1]))
        return res
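    # Illustrative example (assuming n_overlap=2, roll=False):
    #   self.overlap(["ab cd ", "ef gh"])
    #   -> [(0, 6, "ab cd "), (3, 9, "cd ef "), (6, 11, "ef gh")]
    # decode_overlap below inverts this and recovers "ab cd ef gh".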

    def decode_overlap(self, chunks):
        if not chunks:
            return ""
        # Determine total length based on the largest end index
        max_length = max(end for _, end, _ in chunks)
        # Dictionary to store candidate characters at each index
        index_char_map = {i: [] for i in range(max_length)}
        # Populate index_char_map with characters from the chunks
        for start, end, chunk in chunks:
            for i, char in enumerate(chunk):
                index = start + i
                if index < max_length:
                    index_char_map[index].append(char)
        # Reconstruct the text by majority vote at each position
        reconstructed_text = []
        for i in range(max_length):
            most_common_char, _ = Counter(index_char_map[i]).most_common(1)[0]
            reconstructed_text.append(most_common_char)
        res = "".join(reconstructed_text)
        if self.roll:
            # Drop the circular padding added in overlap()
            res = res[len(chunks[0][2]):-len(chunks[-1][2])]
        return res
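

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the parameter values below are
# assumptions for demonstration, not the repository's recommended settings).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = SentenceTokenizerConfig(min_length=8, max_length=16, n_overlap=2)
    tokenizer = SentenceTokenizer(config)

    text = "안녕하세요. 반갑습니다. 오늘 날씨가 좋네요."
    chunks = tokenizer.tokenize(text)            # length-bounded chunks
    spans = tokenizer.overlap(chunks)            # overlapping (start, end, text) spans
    restored = tokenizer.decode_overlap(spans)   # majority-vote reconstruction

    print(chunks)
    print(spans)
    print(restored == text)  # True: the spans cover every character of the input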