# gemma2-2b-kor-deobfuscation/sentence_tokenizer/modeling_sentence_tokenizer.py
import torch
import re
from collections import Counter
from transformers import PretrainedConfig, PreTrainedModel

class SentenceTokenizerConfig(PretrainedConfig):
    model_type = "sentence_tokenizer"

    def __init__(
        self,
        min_length=32,
        max_length=64,
        n_overlap=3,
        roll=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.min_length = min_length  # merge chunks until they exceed this length
        self.max_length = max_length  # split chunks that reach this length
        self.n_overlap = n_overlap    # number of overlapping windows per chunk pair
        self.roll = roll              # wrap first/last chunks around when overlapping

class SentenceTokenizer(PreTrainedModel):
    config_class = SentenceTokenizerConfig

    def __init__(self, config):
        super().__init__(config)
        # Dummy parameter so PreTrainedModel has at least one weight to save/load.
        self.temp_module = torch.nn.Parameter(torch.ones(1))
        self.min_length = config.min_length
        self.max_length = config.max_length
        self.n_overlap = config.n_overlap
        self.roll = config.roll

    def split_text_into_sentences(self, text):
        # Split after any non-Hangul character that is followed by a space
        # (e.g. sentence-final punctuation), keeping the delimiter.
        split_text = re.split(r'([^가-힣] )', text)
        # re.split with a capturing group alternates content and delimiter;
        # glue each delimiter back onto the piece before it.
        split_text = (
            [split_text[i] + split_text[i + 1] for i in range(0, len(split_text) - 1, 2)]
            + ([split_text[-1]] if len(split_text) % 2 != 0 else [])
        )
        return split_text
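
    # Example: split_text_into_sentences("안녕하세요. 오늘 날씨가 좋네요!") ->
    # ["안녕하세요. ", "오늘 날씨가 좋네요!"]
    # The split fires only on a non-Hangul character followed by a space,
    # so spaces between Hangul words do not end a sentence.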

    def merge_chunks(self, chunks):
        merged_chunks = []
        buffer = ""
        for chunk in chunks:
            buffer += chunk
            if len(buffer) > self.min_length:  # once the buffer exceeds min_length, finalize it
                merged_chunks.append(buffer)
                buffer = ""
        # Add any remaining buffer as the last chunk
        if buffer:
            merged_chunks.append(buffer)
        return merged_chunks
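
    # Example (min_length=5): merge_chunks(["abc", "defg", "hi"]) ->
    # ["abcdefg", "hi"]; "abc" alone does not exceed 5 characters, so it is
    # merged with "defg" before being emitted.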

    def merge_chunks_reverse(self, chunks):
        # Same greedy merge as merge_chunks, but applied right-to-left on
        # character-reversed chunks, so a short trailing chunk is absorbed
        # into its left neighbour instead of being left dangling.
        chunks_reverse = [chunk[::-1] for chunk in chunks[::-1]]
        merged_chunks = []
        buffer = ""
        for chunk in chunks_reverse:
            buffer += chunk
            if len(buffer) > self.min_length:  # once the buffer exceeds min_length, finalize it
                merged_chunks.append(buffer)
                buffer = ""
        # Add any remaining buffer as the last chunk
        if buffer:
            merged_chunks.append(buffer)
        # Undo both reversals to restore original order and orientation.
        return [chunk[::-1] for chunk in merged_chunks[::-1]]
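
    # Contrast with merge_chunks (min_length=5):
    #   merge_chunks(["abcdef", "gh"])         -> ["abcdef", "gh"]
    #   merge_chunks_reverse(["abcdef", "gh"]) -> ["abcdefgh"]
    # The reverse pass absorbs the short trailing chunk into its neighbour.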

    def split_text(self, text):
        # Step 1: Split the text into words (each word keeps its trailing spaces)
        words = self.split_space(text)
        # Step 2: Greedily merge words until the merged text would exceed max_length
        splitted_chunks = []
        buffer = []
        for word in words:
            buffer.append(word)  # Tentatively add the word to the buffer
            merged_text = ''.join(buffer)
            # If the merged text exceeds max_length, flush the buffer without it
            if len(merged_text) > self.max_length:
                # Remove the last added word and save the current buffer as a chunk
                buffer.pop()
                if buffer:  # guard against a single word longer than max_length
                    splitted_chunks.append(''.join(buffer))
                buffer = [word]  # Start a new buffer with the current word
        # Step 3: Append the leftover buffer
        if buffer:
            splitted_chunks.append(''.join(buffer))
        return splitted_chunks
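
    # Example (max_length=10): split_text("aaa bbb ccc ddd") ->
    # ["aaa bbb ", "ccc ddd"]; adding "ccc " would make the first chunk
    # 12 characters, so it starts a new chunk instead.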

    def tokenize(self, text):
        splitted_chunks = []
        # Step 1: Split the text into sentence-like pieces
        sentences = self.split_text_into_sentences(text)
        # Step 2: Break any piece that reaches max_length into word-boundary chunks
        for chunk in sentences:
            if len(chunk) >= self.max_length:
                splitted_chunks.extend(self.split_text(chunk))
            else:
                splitted_chunks.append(chunk)
        # Step 3: Merge forward, then backward, so every chunk (including the
        # last one) ends up longer than min_length where possible
        merged_chunks = self.merge_chunks(splitted_chunks)
        merged_chunks = self.merge_chunks_reverse(merged_chunks)
        return merged_chunks

    def split_space(self, text):
        split_text = re.split(r'(\s+)', text)  # keep the whitespace runs as separate items
        # Reattach each whitespace run to the word before it.
        filtered_text = [s + sp for s, sp in zip(split_text[::2], split_text[1::2] + [''])]
        return filtered_text
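
    # Example: split_space("a b  c") -> ["a ", "b  ", "c"]; whitespace runs
    # stay attached to the preceding word, so ''.join() round-trips exactly.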

    def overlap(self, chunks):
        # Produce (start, end, text) spans: for each adjacent chunk pair,
        # emit n_overlap windows that slide from chunk_a into chunk_b in
        # steps of roughly 1/n_overlap of each chunk's words.
        if not chunks:
            return []
        if self.roll:
            # Wrap around: prepend the last chunk and append the first one,
            # so the first/last chunks also get full overlap coverage.
            chunks = [chunks[-1]] + chunks + [chunks[0]]
        res = []
        total_idx = 0
        for chunk_idx in range(len(chunks) - 1):
            chunk_a, chunk_b = chunks[chunk_idx], chunks[chunk_idx + 1]
            chunk_a_words, chunk_b_words = self.split_space(chunk_a), self.split_space(chunk_b)
            # Per-step word counts for each chunk's sliding window
            chunk_a_overlap_length = len(chunk_a_words) // self.n_overlap
            chunk_b_overlap_length = len(chunk_b_words) // self.n_overlap
            for overlap_idx in range(self.n_overlap):
                chunk_a_past = ''.join(chunk_a_words[:chunk_a_overlap_length * overlap_idx])
                chunk_a_overlap = ''.join(chunk_a_words[chunk_a_overlap_length * overlap_idx:])
                chunk_b_overlap = ''.join(chunk_b_words[:chunk_b_overlap_length * overlap_idx])
                overlap = chunk_a_overlap + chunk_b_overlap
                start = total_idx + len(chunk_a_past)
                end = start + len(overlap)
                res.append((start, end, overlap))
            total_idx += len(chunk_a)
        # The last chunk is never a chunk_a above, so emit it as its own span.
        res.append((total_idx, total_idx + len(chunks[-1]), chunks[-1]))
        return res
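
    # Worked example (n_overlap=2, roll=False):
    #   overlap(["ab cd ", "ef gh"]) ->
    #   [(0, 6, "ab cd "), (3, 9, "cd ef "), (6, 11, "ef gh")]
    # Every character position is covered by at least one span, and spans
    # that overlap agree on their shared positions, which is what lets
    # decode_overlap() reconstruct the text by majority vote.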

    def decode_overlap(self, chunks):
        # Inverse of overlap(): rebuild the text from (start, end, text) spans
        # by a per-position majority vote over the span characters.
        if not chunks:
            return ""
        # Determine total length from the largest end index
        max_length = max(end for _, end, _ in chunks)
        # Collect every candidate character proposed for each position
        index_char_map = {i: [] for i in range(max_length)}
        for start, end, chunk in chunks:
            for i, char in enumerate(chunk):
                index = start + i
                if index < max_length:
                    index_char_map[index].append(char)
        # Reconstruct the text by majority vote at each position
        reconstructed_text = []
        for i in range(max_length):
            most_common_char, _ = Counter(index_char_map[i]).most_common(1)[0]
            reconstructed_text.append(most_common_char)
        res = "".join(reconstructed_text)
        if self.roll:
            # Drop the wrapped-around copies of the last/first chunks that
            # overlap() prepended and appended.
            res = res[len(chunks[0][2]):-len(chunks[-1][2])]
        return res
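
# A minimal usage sketch (illustrative, not part of the uploaded file). It
# assumes torch and transformers are installed; the sample text and the
# min_length/max_length/n_overlap values are arbitrary choices for the demo.
if __name__ == "__main__":
    config = SentenceTokenizerConfig(min_length=8, max_length=16, n_overlap=2)
    tokenizer = SentenceTokenizer(config)

    text = "안녕하세요. 오늘 날씨가 좋네요! 같이 산책할까요?"
    chunks = tokenizer.tokenize(text)           # length-bounded chunks
    spans = tokenizer.overlap(chunks)           # (start, end, text) windows
    restored = tokenizer.decode_overlap(spans)  # majority-vote reconstruction

    print(chunks)
    print(spans)
    assert restored == text  # the round trip is lossless when roll=False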