import torch
import re

from collections import Counter
from transformers import PretrainedConfig, PreTrainedModel

class SentenceTokenizerConfig(PretrainedConfig):
    """Configuration for SentenceTokenizer.

    min_length: merged chunks keep growing until they exceed this many characters.
    max_length: sentences at or above this length are split into word-level chunks.
    n_overlap: number of overlapping windows generated per adjacent chunk pair.
    roll: if True, wrap the first/last chunk around when building overlaps.
    """
    model_type = "sentence_tokenizer"
    def __init__(
        self,
        min_length=32,
        max_length=64,
        n_overlap=3,
        roll=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.min_length = min_length
        self.max_length = max_length
        self.n_overlap = n_overlap
        self.roll = roll

class SentenceTokenizer(PreTrainedModel):
    config_class = SentenceTokenizerConfig
    
    def __init__(self, config):
        super().__init__(config)
        # Dummy parameter so this otherwise weight-free module has at least one
        # tensor; presumably this keeps PreTrainedModel save/load and device
        # placement working (inferred intent, not documented in the source).
        self.temp_module = torch.nn.Parameter(torch.ones(1))
        self.min_length = config.min_length
        self.max_length = config.max_length
        self.n_overlap = config.n_overlap
        self.roll = config.roll

    def split_text_into_sentences(self, text):
        # Split after any non-Hangul character that is followed by a space
        # (e.g. ". ", "! ", "? "), keeping the delimiter attached to the
        # preceding piece so that joining the pieces reproduces the input.
        parts = re.split(r'([^가-힣] )', text)
        split_text = [parts[i] + parts[i + 1] for i in range(0, len(parts) - 1, 2)]
        if len(parts) % 2 != 0:
            split_text.append(parts[-1])

        return split_text
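
    # Illustrative example (assumed input, traced against the regex above):
    #   split_text_into_sentences("하나요. 둘이다! three.")
    #   -> ["하나요. ", "둘이다! ", "three."]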

    def merge_chunks(self, chunks):
        merged_chunks = []
        buffer = ""

        for chunk in chunks:
            buffer += chunk
            if len(buffer) > self.min_length:  # Once the buffer exceeds min_length, finalize it
                merged_chunks.append(buffer)
                buffer = ""

        # Add any remaining buffer as the last chunk
        if buffer:
            merged_chunks.append(buffer)

        return merged_chunks

    def merge_chunks_reverse(self, chunks):
        # Same as merge_chunks, but applied from the tail end: reverse the chunk
        # order and each chunk's characters, merge, then undo both reversals.
        # This absorbs a trailing chunk that was left shorter than min_length.
        chunks_reverse = [chunk[::-1] for chunk in chunks[::-1]]

        merged_chunks = []
        buffer = ""

        for chunk in chunks_reverse:
            buffer += chunk
            if len(buffer) > self.min_length:  # Once the buffer exceeds min_length, finalize it
                merged_chunks.append(buffer)
                buffer = ""

        # Add any remaining buffer as the last chunk
        if buffer:
            merged_chunks.append(buffer)

        # Undo the reversal to restore the original orientation
        return [chunk[::-1] for chunk in merged_chunks[::-1]]
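
    # Illustrative trace with min_length=5 (assumed values):
    #   merge_chunks(["ab", "cd", "ef", "g"])  -> ["abcdef", "g"]
    #   merge_chunks_reverse(["abcdef", "g"])  -> ["abcdefg"]
    # The reverse pass folds the short tail "g" into the preceding chunk.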
        
    def split_text(self, text):
        words = self.split_space(text)

        # Greedily merge words while the merged text stays within max_length
        splitted_chunks = []
        buffer = []

        for word in words:
            buffer.append(word)  # Add the word to the buffer
            merged_text = ''.join(buffer)

            # If the merged text exceeds max_length, flush the current buffer
            if len(merged_text) > self.max_length:
                # Remove the last added word and save the current buffer as a chunk
                buffer.pop()
                if buffer:  # Skip empty chunks when a single word exceeds max_length
                    splitted_chunks.append(''.join(buffer))
                buffer = [word]  # Start a new buffer with the current word

        # Append the leftover buffer
        if buffer:
            splitted_chunks.append(''.join(buffer))

        return splitted_chunks
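
    # Illustrative trace with max_length=5 (assumed value):
    #   split_text("ab cd ef") -> ["ab ", "cd ef"]
    # Trailing spaces stay attached to words, so joining the chunks
    # reproduces the original text exactly.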

    def tokenize(self, text):
        splitted_chunks = []
        # Step 1: Split text into sentences
        sentences = self.split_text_into_sentences(text)
        # Step 2: Break sentences that reach max_length into word-level chunks
        for chunk in sentences:
            if len(chunk) >= self.max_length:
                splitted_chunks.extend(self.split_text(chunk))
            else:
                splitted_chunks.append(chunk)
        # Step 3: Merge short chunks past min_length, forward then backward
        merged_chunks = self.merge_chunks(splitted_chunks)
        merged_chunks = self.merge_chunks_reverse(merged_chunks)

        return merged_chunks

    def split_space(self, text):
        # Split on whitespace, keeping each whitespace run attached to the
        # preceding token so the pieces concatenate back to the input.
        split_text = re.split(r'(\s+)', text)
        filtered_text = [s + sp for s, sp in zip(split_text[::2], split_text[1::2] + [''])]
        return filtered_text
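
    # e.g. split_space("ab cd") -> ["ab ", "cd"] (illustrative input)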
    
    def overlap(self, chunks):
        if not chunks:
            return []
        if self.roll:
            # Wrap around: prepend the last chunk and append the first
            chunks = [chunks[-1]] + chunks + [chunks[0]]
        res = []
        total_idx = 0
        for chunk_idx in range(len(chunks) - 1):
            chunk_a, chunk_b = chunks[chunk_idx], chunks[chunk_idx + 1]
            chunk_a_words = self.split_space(chunk_a)
            chunk_b_words = self.split_space(chunk_b)
            # Word-level step size of the sliding window inside each chunk
            chunk_a_overlap_length = len(chunk_a_words) // self.n_overlap
            chunk_b_overlap_length = len(chunk_b_words) // self.n_overlap
            for overlap_idx in range(self.n_overlap):
                # Drop overlap_idx steps from the head of chunk_a and take
                # overlap_idx steps from the head of chunk_b, so successive
                # windows slide across the boundary between the two chunks.
                chunk_a_past = ''.join(chunk_a_words[:chunk_a_overlap_length * overlap_idx])
                chunk_a_overlap = ''.join(chunk_a_words[chunk_a_overlap_length * overlap_idx:])
                chunk_b_overlap = ''.join(chunk_b_words[:chunk_b_overlap_length * overlap_idx])
                overlap = chunk_a_overlap + chunk_b_overlap
                # Record the window with its character offsets in the full text
                start = total_idx + len(chunk_a_past)
                end = start + len(overlap)
                res.append((start, end, overlap))
            total_idx += len(chunk_a)
        # The last chunk never starts a pair, so append it as its own window
        res.append((total_idx, total_idx + len(chunks[-1]), chunks[-1]))

        return res
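
    # Illustrative trace with n_overlap=2, roll=False (assumed values):
    #   overlap(["ab cd ", "ef gh"])
    #   -> [(0, 6, "ab cd "), (3, 9, "cd ef "), (6, 11, "ef gh")]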

    def decode_overlap(self, chunks):
        if not chunks:
            return ""
        
        # Determine total length based on the largest end index
        max_length = max(end for _, end, _ in chunks)
        
        # Dictionary to store characters at each index
        index_char_map = {i: [] for i in range(max_length)}
        
        # Populate index_char_map with characters from chunks
        for start, end, chunk in chunks:
            for i, char in enumerate(chunk):
                index = start + i
                if index < max_length:
                    index_char_map[index].append(char)
        
        # Reconstruct text using majority vote
        reconstructed_text = []
        for i in range(max_length):
            most_common_char, _ = Counter(index_char_map[i]).most_common(1)[0]
            reconstructed_text.append(most_common_char)
        res = "".join(reconstructed_text)
        if self.roll:
            res = res[len(chunks[0][2]):-len(chunks[-1][2])]
    
        return res
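
# --- Usage sketch (illustrative, not part of the original file) --------------
# A minimal, hedged example of the intended round trip: tokenize a text into
# chunks, build overlapping windows, and reconstruct the text by per-character
# majority vote. The sample text and parameter values below are assumptions.
if __name__ == "__main__":
    config = SentenceTokenizerConfig(min_length=8, max_length=24, n_overlap=3)
    tokenizer = SentenceTokenizer(config)

    text = "안녕하세요. 이것은 예시 문장입니다. Overlap decoding should round-trip."
    chunks = tokenizer.tokenize(text)            # chunks concatenate back to text
    spans = tokenizer.overlap(chunks)            # [(start, end, window), ...]
    restored = tokenizer.decode_overlap(spans)
    print(restored == text)                      # expected: True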