jwengr committed
Commit 4b05f4d · verified · 1 Parent(s): c528172

Upload folder using huggingface_hub

__init__.py CHANGED
@@ -0,0 +1 @@
+
modeling_sentence_tokenizer.py ADDED
@@ -0,0 +1,170 @@
+ import torch
+ import re
+
+ from collections import Counter
+ from transformers import PretrainedConfig, PreTrainedModel
+
+ class SentenceTokenizerConfig(PretrainedConfig):
+     model_type = "sentence_tokenizer"
+     def __init__(
+         self,
+         min_length=32,
+         max_length=64,
+         n_overlap=3,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.min_length = min_length
+         self.max_length = max_length
+         self.n_overlap = n_overlap
+
+ class SentenceTokenizer(PreTrainedModel):
+     config_class = SentenceTokenizerConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         # Placeholder parameter (the model carries no learned weights)
+         self.temp_module = torch.nn.Parameter(torch.ones(1))
+         self.min_length = config.min_length
+         self.max_length = config.max_length
+         self.n_overlap = config.n_overlap
+
+     def split_text_into_sentences(self, text):
+         # Split after any non-Hangul character followed by a space, keeping the delimiter
+         split_text = re.split(r'([^가-힣] )', text)
+         split_text = [split_text[i] + split_text[i + 1] for i in range(0, len(split_text) - 1, 2)] + ([split_text[-1]] if len(split_text) % 2 != 0 else [])
+
+         return split_text
+
+     def merge_chunks(self, chunks):
+         merged_chunks = []
+         buffer = ""
+
+         for chunk in chunks:
+             buffer += chunk
+             if len(buffer) > self.min_length:  # If the buffer meets the min length, finalize it
+                 merged_chunks.append(buffer)
+                 buffer = ""
+
+         # Add any remaining buffer as the last chunk
+         if buffer:
+             merged_chunks.append(buffer)
+
+         return merged_chunks
+
+     def merge_chunks_reverse(self, chunks):
+         chunks_reverse = []
+         for chunk in chunks[::-1]:
+             chunks_reverse.append(chunk[::-1])
+
+         merged_chunks = []
+         buffer = ""
+
+         for chunk in chunks_reverse:
+             buffer += chunk
+             if len(buffer) > self.min_length:  # If the buffer meets the min length, finalize it
+                 merged_chunks.append(buffer)
+                 buffer = ""
+
+         # Add any remaining buffer as the last chunk
+         if buffer:
+             merged_chunks.append(buffer)
+
+         res_merged_chunks = []
+         for chunk in merged_chunks[::-1]:
+             res_merged_chunks.append(chunk[::-1])
+
+         return res_merged_chunks
+
+     def split_text(self, text):
+         words = self.split_space(text)
+
+         # Step 2: Greedily merge words until the merged text would exceed max_length
+         splitted_chunks = []
+         buffer = []
+
+         for word in words:
+             buffer.append(word)  # Add the word to the buffer
+             merged_text = ''.join(buffer)
+
+             # If the merged text exceeds max_length, push the current buffer to the result
+             if len(merged_text) > self.max_length:
+                 # Remove the last added word and save the current buffer as a chunk
+                 buffer.pop()
+                 splitted_chunks.append(''.join(buffer))
+                 buffer = [word]  # Start a new buffer with the current word
+
+         # Step 3: Append the leftover buffer
+         if buffer:
+             splitted_chunks.append(''.join(buffer))
+
+         return splitted_chunks
+
+     def tokenize(self, text):
+         splitted_chunks = []
+         # Step 1: Split text into sentences
+         sentences = self.split_text_into_sentences(text)
+         for chunk in sentences:
+             if len(chunk) >= self.max_length:
+                 splitted_chunks.extend(self.split_text(chunk))
+             else:
+                 splitted_chunks.append(chunk)
+         merged_chunks = self.merge_chunks(splitted_chunks)
+         merged_chunks = self.merge_chunks_reverse(merged_chunks)
+
+         return merged_chunks
+
+     def split_space(self, text):
+         split_text = re.split(r'(\s+)', text)  # Keep spaces as part of the split parts
+         filtered_text = [s + sp for s, sp in zip(split_text[::2], split_text[1::2] + [''])]
+         return filtered_text
+
+     def overlap(self, chunks, roll=False):
+         if not chunks:
+             return []
+         if roll:
+             chunks = [chunks[-1]] + chunks + [chunks[0]]
+         res = []
+         total_idx = 0
+         for chunk_idx in range(len(chunks) - 1):
+             chunk_a, chunk_b = chunks[chunk_idx], chunks[chunk_idx + 1]
+             chunk_a_words, chunk_b_words = self.split_space(chunk_a), self.split_space(chunk_b)
+             chunk_a_overlap_length, chunk_b_overlap_length = len(chunk_a_words) // self.n_overlap, len(chunk_b_words) // self.n_overlap
+             for overlap_idx in range(self.n_overlap):
+                 chunk_a_past, chunk_a_overlap, chunk_b_overlap = ''.join(chunk_a_words[:chunk_a_overlap_length * overlap_idx]), ''.join(chunk_a_words[chunk_a_overlap_length * overlap_idx:]), ''.join(chunk_b_words[:chunk_b_overlap_length * overlap_idx])
+                 overlap = chunk_a_overlap + chunk_b_overlap
+                 start = total_idx + len(chunk_a_past)
+                 end = start + len(overlap)
+                 res.append((start, end, overlap))
+             total_idx += len(chunk_a)
+         res.append((total_idx, total_idx + len(chunks[-1]), chunks[-1]))
+
+         return res
+
+     def decode_overlap(self, chunks, roll=False):
+         if not chunks:
+             return ""
+
+         # Determine total length based on the largest end index
+         max_length = max(end for _, end, _ in chunks)
+
+         # Dictionary to store characters at each index
+         index_char_map = {i: [] for i in range(max_length)}
+
+         # Populate index_char_map with characters from chunks
+         for start, end, chunk in chunks:
+             for i, char in enumerate(chunk):
+                 index = start + i
+                 if index < max_length:
+                     index_char_map[index].append(char)
+
+         # Reconstruct text using majority vote
+         reconstructed_text = []
+         for i in range(max_length):
+             most_common_char, _ = Counter(index_char_map[i]).most_common(1)[0]
+             reconstructed_text.append(most_common_char)
+         res = "".join(reconstructed_text)
+         if roll:
+             res = res[len(chunks[0][2]):-len(chunks[-1][2])]
+
+         return res
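
A minimal usage sketch of the class above, assuming the file is importable from the working directory; the sample text and variable names are illustrative and not part of the commit:

from modeling_sentence_tokenizer import SentenceTokenizer, SentenceTokenizerConfig

config = SentenceTokenizerConfig(min_length=32, max_length=64, n_overlap=3)
tokenizer = SentenceTokenizer(config)

text = "안녕하세요. 이 모듈은 긴 문서를 문장 단위 청크로 나눕니다. It also handles mixed-language text."
chunks = tokenizer.tokenize(text)           # chunks of roughly min_length to max_length characters
spans = tokenizer.overlap(chunks)           # (start, end, text) windows overlapping adjacent chunks
restored = tokenizer.decode_overlap(spans)  # character-level majority vote over the windows
assert restored == text                     # the round trip reproduces the input for this example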
sentence_tokenizer/config.json ADDED
@@ -0,0 +1,17 @@
+ {
+   "architectures": [
+     "SentenceTokenizer"
+   ],
+   "auto_map": {
+     "AutoConfig": "modeling_sentence_tokenizer.SentenceTokenizerConfig",
+     "AutoModel": [
+       "modeling_sentence_tokenizer.SentenceTokenizer"
+     ]
+   },
+   "max_length": 64,
+   "min_length": 32,
+   "model_type": "sentence_tokenizer",
+   "n_overlap": 3,
+   "torch_dtype": "float32",
+   "transformers_version": "4.48.0"
+ }
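
The auto_map entries above are what let the Auto classes resolve the custom code; a hedged loading sketch, where the repo id is a placeholder rather than something taken from this commit:

from transformers import AutoModel

# placeholder repo id; trust_remote_code=True is needed because the modeling code ships with the repo
tokenizer = AutoModel.from_pretrained("<user>/<repo>", trust_remote_code=True)
chunks = tokenizer.tokenize("나누고 싶은 긴 한국어 문서")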
sentence_tokenizer/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b2b29affbce2da50bace6c60697df30b796ff62cba44ab8755d6b264abebc0de
+ size 108
sentence_tokenizer/modeling_sentence_tokenizer.py ADDED
@@ -0,0 +1,170 @@
+ import torch
+ import re
+
+ from collections import Counter
+ from transformers import PretrainedConfig, PreTrainedModel
+
+ class SentenceTokenizerConfig(PretrainedConfig):
+     model_type = "sentence_tokenizer"
+     def __init__(
+         self,
+         min_length=32,
+         max_length=64,
+         n_overlap=3,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.min_length = min_length
+         self.max_length = max_length
+         self.n_overlap = n_overlap
+
+ class SentenceTokenizer(PreTrainedModel):
+     config_class = SentenceTokenizerConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         # Placeholder parameter (the model carries no learned weights)
+         self.temp_module = torch.nn.Parameter(torch.ones(1))
+         self.min_length = config.min_length
+         self.max_length = config.max_length
+         self.n_overlap = config.n_overlap
+
+     def split_text_into_sentences(self, text):
+         # Split after any non-Hangul character followed by a space, keeping the delimiter
+         split_text = re.split(r'([^가-힣] )', text)
+         split_text = [split_text[i] + split_text[i + 1] for i in range(0, len(split_text) - 1, 2)] + ([split_text[-1]] if len(split_text) % 2 != 0 else [])
+
+         return split_text
+
+     def merge_chunks(self, chunks):
+         merged_chunks = []
+         buffer = ""
+
+         for chunk in chunks:
+             buffer += chunk
+             if len(buffer) > self.min_length:  # If the buffer meets the min length, finalize it
+                 merged_chunks.append(buffer)
+                 buffer = ""
+
+         # Add any remaining buffer as the last chunk
+         if buffer:
+             merged_chunks.append(buffer)
+
+         return merged_chunks
+
+     def merge_chunks_reverse(self, chunks):
+         chunks_reverse = []
+         for chunk in chunks[::-1]:
+             chunks_reverse.append(chunk[::-1])
+
+         merged_chunks = []
+         buffer = ""
+
+         for chunk in chunks_reverse:
+             buffer += chunk
+             if len(buffer) > self.min_length:  # If the buffer meets the min length, finalize it
+                 merged_chunks.append(buffer)
+                 buffer = ""
+
+         # Add any remaining buffer as the last chunk
+         if buffer:
+             merged_chunks.append(buffer)
+
+         res_merged_chunks = []
+         for chunk in merged_chunks[::-1]:
+             res_merged_chunks.append(chunk[::-1])
+
+         return res_merged_chunks
+
+     def split_text(self, text):
+         words = self.split_space(text)
+
+         # Step 2: Greedily merge words until the merged text would exceed max_length
+         splitted_chunks = []
+         buffer = []
+
+         for word in words:
+             buffer.append(word)  # Add the word to the buffer
+             merged_text = ''.join(buffer)
+
+             # If the merged text exceeds max_length, push the current buffer to the result
+             if len(merged_text) > self.max_length:
+                 # Remove the last added word and save the current buffer as a chunk
+                 buffer.pop()
+                 splitted_chunks.append(''.join(buffer))
+                 buffer = [word]  # Start a new buffer with the current word
+
+         # Step 3: Append the leftover buffer
+         if buffer:
+             splitted_chunks.append(''.join(buffer))
+
+         return splitted_chunks
+
+     def tokenize(self, text):
+         splitted_chunks = []
+         # Step 1: Split text into sentences
+         sentences = self.split_text_into_sentences(text)
+         for chunk in sentences:
+             if len(chunk) >= self.max_length:
+                 splitted_chunks.extend(self.split_text(chunk))
+             else:
+                 splitted_chunks.append(chunk)
+         merged_chunks = self.merge_chunks(splitted_chunks)
+         merged_chunks = self.merge_chunks_reverse(merged_chunks)
+
+         return merged_chunks
+
+     def split_space(self, text):
+         split_text = re.split(r'(\s+)', text)  # Keep spaces as part of the split parts
+         filtered_text = [s + sp for s, sp in zip(split_text[::2], split_text[1::2] + [''])]
+         return filtered_text
+
+     def overlap(self, chunks, roll=False):
+         if not chunks:
+             return []
+         if roll:
+             chunks = [chunks[-1]] + chunks + [chunks[0]]
+         res = []
+         total_idx = 0
+         for chunk_idx in range(len(chunks) - 1):
+             chunk_a, chunk_b = chunks[chunk_idx], chunks[chunk_idx + 1]
+             chunk_a_words, chunk_b_words = self.split_space(chunk_a), self.split_space(chunk_b)
+             chunk_a_overlap_length, chunk_b_overlap_length = len(chunk_a_words) // self.n_overlap, len(chunk_b_words) // self.n_overlap
+             for overlap_idx in range(self.n_overlap):
+                 chunk_a_past, chunk_a_overlap, chunk_b_overlap = ''.join(chunk_a_words[:chunk_a_overlap_length * overlap_idx]), ''.join(chunk_a_words[chunk_a_overlap_length * overlap_idx:]), ''.join(chunk_b_words[:chunk_b_overlap_length * overlap_idx])
+                 overlap = chunk_a_overlap + chunk_b_overlap
+                 start = total_idx + len(chunk_a_past)
+                 end = start + len(overlap)
+                 res.append((start, end, overlap))
+             total_idx += len(chunk_a)
+         res.append((total_idx, total_idx + len(chunks[-1]), chunks[-1]))
+
+         return res
+
+     def decode_overlap(self, chunks, roll=False):
+         if not chunks:
+             return ""
+
+         # Determine total length based on the largest end index
+         max_length = max(end for _, end, _ in chunks)
+
+         # Dictionary to store characters at each index
+         index_char_map = {i: [] for i in range(max_length)}
+
+         # Populate index_char_map with characters from chunks
+         for start, end, chunk in chunks:
+             for i, char in enumerate(chunk):
+                 index = start + i
+                 if index < max_length:
+                     index_char_map[index].append(char)
+
+         # Reconstruct text using majority vote
+         reconstructed_text = []
+         for i in range(max_length):
+             most_common_char, _ = Counter(index_char_map[i]).most_common(1)[0]
+             reconstructed_text.append(most_common_char)
+         res = "".join(reconstructed_text)
+         if roll:
+             res = res[len(chunks[0][2]):-len(chunks[-1][2])]
+
+         return res