# improved_train_tokenizer_v2.py

import os
import sys
from tokenizers import Tokenizer, Regex, models, pre_tokenizers, decoders, trainers, processors, normalizers
from transformers import PreTrainedTokenizerFast

# --- Configuration ---
TRAIN_FILES = ["improved_sentences.txt"]  # Use the preprocessed file
VOCAB_SIZE = 32000
SPECIAL_TOKENS = ["<pad>", "<unk>", "<s>", "</s>", "<mask>"]
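# Special tokens are added to the vocabulary first, in this order, so they should end up with IDs 0-4
# (<pad>=0, <unk>=1, <s>=2, </s>=3, <mask>=4); having <pad> at ID 0 is convenient for padding.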
OUTPUT_DIR = "./improved_tokenizer_v2"

# --- Input File Check ---
missing = [f for f in TRAIN_FILES if not os.path.exists(f)]
if not TRAIN_FILES or missing:
    print(f"Error: No training files configured or files not found: {missing}")
    sys.exit(1)

print(f"Starting tokenizer training...")
print(f"Training file(s): {TRAIN_FILES}")
print(f"Target vocab size: {VOCAB_SIZE}")
print(f"Output directory: {OUTPUT_DIR}")

# --- Initialize Tokenizer ---
# We'll use ByteLevel BPE with proper whitespace handling
tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))

# --- Set Normalizer ---
# This helps standardize the text before tokenization
tokenizer.normalizer = normalizers.Sequence([
    normalizers.NFC(),  # Unicode normalization
    # Replace() treats a plain string as a literal pattern; wrap it in Regex() so "\s+" is applied as a regex
    normalizers.Replace(Regex(r"\s+"), " ")  # Collapse runs of whitespace into a single space
])

# --- Set Pre-tokenizer ---
# This is critical for handling whitespace correctly
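# With add_prefix_space=True, "Hello" and " Hello" map to the same byte-level pre-token ("ĠHello"),
# so a word is tokenized consistently whether it starts a line or follows a space.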
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True)  # Back to True for proper space handling
print(f"Using pre-tokenizer: ByteLevel(add_prefix_space=True)")

# --- Set Decoder ---
tokenizer.decoder = decoders.ByteLevel()
print(f"Using decoder: {tokenizer.decoder.__class__.__name__}")

# --- Define Trainer ---
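# Seeding the trainer with ByteLevel.alphabet() puts all 256 byte-level symbols in the vocabulary,
# so no input byte ever has to fall back to <unk>.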
trainer = trainers.BpeTrainer(
    vocab_size=VOCAB_SIZE,
    special_tokens=SPECIAL_TOKENS,
    show_progress=True,
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
)

# --- Train Tokenizer ---
print("\nTraining the tokenizer model (this might take a while)...")
try:
    tokenizer.train(files=TRAIN_FILES, trainer=trainer)
    print("Training completed successfully.")
except Exception as e:
    print(f"\nError during tokenizer training: {e}")
    sys.exit(1)

# --- Add Post-processor ---
tokenizer.post_processor = processors.TemplateProcessing(
    single="<s> $A </s>",
    pair="<s> $A </s> $B </s>",
    special_tokens=[
        ("<s>", tokenizer.token_to_id("<s>")),
        ("</s>", tokenizer.token_to_id("</s>")),
    ],
)
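# With this template, a single sequence is wrapped as "<s> ... </s>"; a pair shares one leading <s>
# and each segment is closed by "</s>".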

# --- Save Core Tokenizer ---
os.makedirs(OUTPUT_DIR, exist_ok=True)
tokenizer_path = os.path.join(OUTPUT_DIR, "tokenizer.json")
try:
    tokenizer.save(tokenizer_path)
    print(f"\nCore tokenizer saved to: {tokenizer_path}")
except Exception as e:
    print(f"Error saving core tokenizer: {e}")
    sys.exit(1)

# --- Create and Save HF Wrapper ---
print("\nWrapping tokenizer with PreTrainedTokenizerFast...")
try:
    hf_tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=tokenizer_path,
        unk_token="<unk>",
        pad_token="<pad>",
        bos_token="<s>",   # also expose <s>/</s> as BOS/EOS for models that look them up
        eos_token="</s>",
        cls_token="<s>",
        sep_token="</s>",
        mask_token="<mask>",
        add_prefix_space=True  # Match the pre-tokenizer setting
    )
    hf_tokenizer.save_pretrained(OUTPUT_DIR)
    print(f"Hugging Face compatible tokenizer files saved to: {OUTPUT_DIR}")
except Exception as e:
    print(f"Error saving Hugging Face tokenizer: {e}")
    sys.exit(1)
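# The saved directory can later be reloaded in downstream code, for example:
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR)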

# --- Verification Step ---
print("\n--- Verification ---")
try:
    print(f"Loading tokenizer for verification from: {OUTPUT_DIR}")
    loaded_hf_tokenizer = PreTrainedTokenizerFast.from_pretrained(OUTPUT_DIR)

    # Test multiple cases, especially those starting with periods or spaces
    test_cases = [
        "Simple sentence.",
        " Sentence starting with space.",
        "Sentence. Another sentence.",
        ". Sentence starting with period.", 
        "Word.Word",
        "The quick brown fox jumps over the lazy dog."
    ]
    
    print("\n=== Testing with new tokenizer ===")
    for i, text in enumerate(test_cases):
        print(f"\nTest {i+1}: '{text}'")
        tokens = loaded_hf_tokenizer.tokenize(text)
        print(f"Tokens: {tokens}")
        
        encoded = loaded_hf_tokenizer.encode(text, add_special_tokens=True)
        decoded = loaded_hf_tokenizer.decode(encoded, skip_special_tokens=True)
        print(f"Encoded: {encoded}")
        print(f"Decoded: '{decoded}'")
        
        # Check if tokenization properly preserves content
        if text.strip() == decoded.strip():
            print("✓ Encoding/decoding preserved text content")
        else:
            print(f"⚠ Warning: Text content changed during encoding/decoding")
            print(f"  Original: '{text}'")
            print(f"  Decoded:  '{decoded}'")

    # Check first token distributions
    print("\n=== First Position Token Analysis ===")
    print("Analyzing first token after <s> for potential bias...")
    
    # Simplified analysis of first token (just for demonstration)
    from collections import Counter
    first_token_counter = Counter()
    
    with open(TRAIN_FILES[0], 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if i >= 100:  # Just check first 100 lines
                break
            line = line.strip()
            if not line:
                continue
                
            encoded = loaded_hf_tokenizer.encode(line, add_special_tokens=True)
            if len(encoded) > 1:  # Make sure there's at least one token after <s>
                first_token_id = encoded[1]
                first_token_counter[first_token_id] += 1
    
    total = sum(first_token_counter.values())
    if total > 0:
        print(f"\nTop 5 tokens at first position (after <s>) from {total} samples:")
        for token_id, count in first_token_counter.most_common(5):
            token_text = loaded_hf_tokenizer.decode([token_id])
            percentage = (count / total) * 100
            print(f"Token: '{token_text}' (ID: {token_id}) | Count: {count} | {percentage:.2f}%")
            
        # Specifically check the period token; encode('.') may return more than one ID under ByteLevel,
        # so taking the first ID is a heuristic
        period_id = loaded_hf_tokenizer.encode('.', add_special_tokens=False)[0]
        period_count = first_token_counter.get(period_id, 0)
        period_percentage = (period_count / total) * 100 if total > 0 else 0
        print(f"\nPeriod token ('.', ID: {period_id}) at first position: {period_count} times ({period_percentage:.2f}%)")
    
except Exception as e:
    print(f"Error during verification: {e}")

print("\n--- Tokenizer training script finished ---")