dejanseo committed
Commit 0ad5c1a · verified · 1 Parent(s): a08869d

Upload train_tokenizer.py

Files changed (1)
  1. train_tokenizer.py +172 -0
train_tokenizer.py ADDED
@@ -0,0 +1,172 @@
# improved_train_tokenizer_v2.py

import os
import sys
from tokenizers import Regex, Tokenizer, models, pre_tokenizers, decoders, trainers, processors, normalizers
from transformers import PreTrainedTokenizerFast

# --- Configuration ---
TRAIN_FILES = ["improved_sentences.txt"]  # Use the preprocessed file
VOCAB_SIZE = 32000
SPECIAL_TOKENS = ["<pad>", "<unk>", "<s>", "</s>", "<mask>"]
OUTPUT_DIR = "./improved_tokenizer_v2"

# --- Input File Check ---
if not TRAIN_FILES or not os.path.exists(TRAIN_FILES[0]):
    print(f"Error: Training file '{TRAIN_FILES[0]}' not found.")
    sys.exit(1)

print("Starting tokenizer training...")
print(f"Training file(s): {TRAIN_FILES}")
print(f"Target vocab size: {VOCAB_SIZE}")
print(f"Output directory: {OUTPUT_DIR}")

# --- Initialize Tokenizer ---
# We'll use ByteLevel BPE with proper whitespace handling
tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))

# --- Set Normalizer ---
# This helps standardize the text before tokenization
tokenizer.normalizer = normalizers.Sequence([
    normalizers.NFC(),                       # Unicode normalization
    normalizers.Replace(Regex(r"\s+"), " ")  # Collapse runs of whitespace into a single space (a Regex is required; a plain string pattern would only match the literal text "\s+")
])

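# Illustrative sketch (not required for training): preview the normalizer on a sample
# string to confirm whitespace runs collapse to single spaces; sample_text is an example value.
sample_text = "Hello\tworld   with  extra   spaces"
print(f"Normalizer preview: '{tokenizer.normalizer.normalize_str(sample_text)}'")
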
# --- Set Pre-tokenizer ---
# This is critical for handling whitespace correctly
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True)  # Back to True for proper space handling
print("Using pre-tokenizer: ByteLevel(add_prefix_space=True)")

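# Illustrative sketch: pre_tokenize_str shows how ByteLevel splits a sample string before
# the BPE model sees it; spaces are carried as part of each word piece (the 'Ġ' marker).
print(f"Pre-tokenizer preview: {tokenizer.pre_tokenizer.pre_tokenize_str('Hello world')}")
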
# --- Set Decoder ---
tokenizer.decoder = decoders.ByteLevel()
print(f"Using decoder: {tokenizer.decoder.__class__.__name__}")

# --- Define Trainer ---
trainer = trainers.BpeTrainer(
    vocab_size=VOCAB_SIZE,
    special_tokens=SPECIAL_TOKENS,
    show_progress=True,
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
)

# --- Train Tokenizer ---
print("\nTraining the tokenizer model (this might take a while)...")
try:
    tokenizer.train(files=TRAIN_FILES, trainer=trainer)
    print("Training completed successfully.")
except Exception as e:
    print(f"\nError during tokenizer training: {e}")
    sys.exit(1)

# --- Add Post-processor ---
tokenizer.post_processor = processors.TemplateProcessing(
    single="<s> $A </s>",
    pair="<s> $A </s> $B </s>",
    special_tokens=[
        ("<s>", tokenizer.token_to_id("<s>")),
        ("</s>", tokenizer.token_to_id("</s>")),
    ],
)

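# Illustrative sketch: with TemplateProcessing attached, a single encoded sequence should
# come back wrapped in the special tokens defined above; demo_encoding is an example name.
demo_encoding = tokenizer.encode("Hello world")
print(f"Post-processor preview: {demo_encoding.tokens}")  # expected to start with <s> and end with </s>
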
# --- Save Core Tokenizer ---
os.makedirs(OUTPUT_DIR, exist_ok=True)
tokenizer_path = os.path.join(OUTPUT_DIR, "tokenizer.json")
try:
    tokenizer.save(tokenizer_path)
    print(f"\nCore tokenizer saved to: {tokenizer_path}")
except Exception as e:
    print(f"Error saving core tokenizer: {e}")
    sys.exit(1)

# --- Create and Save HF Wrapper ---
print("\nWrapping tokenizer with PreTrainedTokenizerFast...")
try:
    hf_tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=tokenizer_path,
        unk_token="<unk>",
        pad_token="<pad>",
        cls_token="<s>",
        sep_token="</s>",
        mask_token="<mask>",
        add_prefix_space=True  # Match the pre-tokenizer setting
    )
    hf_tokenizer.save_pretrained(OUTPUT_DIR)
    print(f"Hugging Face compatible tokenizer files saved to: {OUTPUT_DIR}")
except Exception as e:
    print(f"Error saving Hugging Face tokenizer: {e}")
    sys.exit(1)

# --- Verification Step ---
print("\n--- Verification ---")
try:
    print(f"Loading tokenizer for verification from: {OUTPUT_DIR}")
    loaded_hf_tokenizer = PreTrainedTokenizerFast.from_pretrained(OUTPUT_DIR)

    # Test multiple cases, especially those starting with periods or spaces
    test_cases = [
        "Simple sentence.",
        " Sentence starting with space.",
        "Sentence. Another sentence.",
        ". Sentence starting with period.",
        "Word.Word",
        "The quick brown fox jumps over the lazy dog."
    ]

    print("\n=== Testing with new tokenizer ===")
    for i, text in enumerate(test_cases):
        print(f"\nTest {i+1}: '{text}'")
        tokens = loaded_hf_tokenizer.tokenize(text)
        print(f"Tokens: {tokens}")

        encoded = loaded_hf_tokenizer.encode(text, add_special_tokens=True)
        decoded = loaded_hf_tokenizer.decode(encoded, skip_special_tokens=True)
        print(f"Encoded: {encoded}")
        print(f"Decoded: '{decoded}'")

        # Check if tokenization properly preserves content
        if text.strip() == decoded.strip():
            print("✓ Encoding/decoding preserved text content")
        else:
            print("⚠ Warning: Text content changed during encoding/decoding")
            print(f"  Original: '{text}'")
            print(f"  Decoded:  '{decoded}'")

    # Check first token distributions
    print("\n=== First Position Token Analysis ===")
    print("Analyzing first token after <s> for potential bias...")

    # Simplified analysis of first token (just for demonstration)
    from collections import Counter
    first_token_counter = Counter()

    with open(TRAIN_FILES[0], 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if i >= 100:  # Just check first 100 lines
                break
            line = line.strip()
            if not line:
                continue

            encoded = loaded_hf_tokenizer.encode(line, add_special_tokens=True)
            if len(encoded) > 1:  # Make sure there's at least one token after <s>
                first_token_id = encoded[1]
                first_token_counter[first_token_id] += 1

    total = sum(first_token_counter.values())
    if total > 0:
        print(f"\nTop 5 tokens at first position (after <s>) from {total} samples:")
        for token_id, count in first_token_counter.most_common(5):
            token_text = loaded_hf_tokenizer.decode([token_id])
            percentage = (count / total) * 100
            print(f"Token: '{token_text}' (ID: {token_id}) | Count: {count} | {percentage:.2f}%")

    # Specifically check period token
    period_id = loaded_hf_tokenizer.encode('.', add_special_tokens=False)[0]
    period_count = first_token_counter.get(period_id, 0)
    period_percentage = (period_count / total) * 100 if total > 0 else 0
    print(f"\nPeriod token ('.', ID: {period_id}) at first position: {period_count} times ({period_percentage:.2f}%)")

except Exception as e:
    print(f"Error during verification: {e}")

print("\n--- Tokenizer training script finished ---")