# improved_train_tokenizer_v2.py
import os
import sys
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, trainers, processors, normalizers, Regex
from transformers import PreTrainedTokenizerFast
# --- Configuration ---
TRAIN_FILES = ["improved_sentences.txt"] # Use the preprocessed file
VOCAB_SIZE = 32000
SPECIAL_TOKENS = ["<pad>", "<unk>", "<s>", "</s>", "<mask>"]
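# Note (assumption): BpeTrainer adds special_tokens to the vocabulary first, so with this
# ordering they typically receive IDs 0-4 (<pad>=0, <unk>=1, <s>=2, </s>=3, <mask>=4).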
OUTPUT_DIR = "./improved_tokenizer_v2"
# --- Input File Check ---
if not TRAIN_FILES or not os.path.exists(TRAIN_FILES[0]):
print(f"Error: Training file '{TRAIN_FILES[0]}' not found.")
sys.exit(1)
print(f"Starting tokenizer training...")
print(f"Training file(s): {TRAIN_FILES}")
print(f"Target vocab size: {VOCAB_SIZE}")
print(f"Output directory: {OUTPUT_DIR}")
# --- Initialize Tokenizer ---
# We'll use ByteLevel BPE with proper whitespace handling
tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
# --- Set Normalizer ---
# This helps standardize the text before tokenization
tokenizer.normalizer = normalizers.Sequence([
    normalizers.NFC(),  # Unicode normalization
    normalizers.Replace(Regex(r"\s+"), " ")  # Wrap in Regex() so the pattern is treated as a regex, not a literal string
])
# --- Set Pre-tokenizer ---
# This is critical for handling whitespace correctly
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True) # Back to True for proper space handling
print(f"Using pre-tokenizer: ByteLevel(add_prefix_space=True)")
# --- Set Decoder ---
tokenizer.decoder = decoders.ByteLevel()
print(f"Using decoder: {tokenizer.decoder.__class__.__name__}")
# --- Define Trainer ---
trainer = trainers.BpeTrainer(
    vocab_size=VOCAB_SIZE,
    special_tokens=SPECIAL_TOKENS,
    show_progress=True,
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
)
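# Note: initial_alphabet seeds the vocabulary with all 256 byte-level symbols, so any input
# can be encoded without ever falling back to <unk>.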
# --- Train Tokenizer ---
print("\nTraining the tokenizer model (this might take a while)...")
try:
    tokenizer.train(files=TRAIN_FILES, trainer=trainer)
    print("Training completed successfully.")
except Exception as e:
    print(f"\nError during tokenizer training: {e}")
    sys.exit(1)
# --- Add Post-processor ---
tokenizer.post_processor = processors.TemplateProcessing(
single="<s> $A </s>",
pair="<s> $A </s> $B </s>",
special_tokens=[
("<s>", tokenizer.token_to_id("<s>")),
("</s>", tokenizer.token_to_id("</s>")),
],
)
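# Quick illustrative check (assumes training succeeded): the template should wrap a single
# sequence as <s> ... </s>.
_demo = tokenizer.encode("Hello world")
print(f"Post-processed tokens: {_demo.tokens}")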
# --- Save Core Tokenizer ---
os.makedirs(OUTPUT_DIR, exist_ok=True)
tokenizer_path = os.path.join(OUTPUT_DIR, "tokenizer.json")
try:
    tokenizer.save(tokenizer_path)
    print(f"\nCore tokenizer saved to: {tokenizer_path}")
except Exception as e:
    print(f"Error saving core tokenizer: {e}")
    sys.exit(1)
# --- Create and Save HF Wrapper ---
print("\nWrapping tokenizer with PreTrainedTokenizerFast...")
try:
    hf_tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=tokenizer_path,
        unk_token="<unk>",
        pad_token="<pad>",
        cls_token="<s>",
        sep_token="</s>",
        mask_token="<mask>",
        add_prefix_space=True  # Match the pre-tokenizer setting
    )
    hf_tokenizer.save_pretrained(OUTPUT_DIR)
    print(f"Hugging Face compatible tokenizer files saved to: {OUTPUT_DIR}")
except Exception as e:
    print(f"Error saving Hugging Face tokenizer: {e}")
    sys.exit(1)
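# Note: save_pretrained() typically writes tokenizer.json, tokenizer_config.json and
# special_tokens_map.json; from_pretrained() below reads these back.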
# --- Verification Step ---
print("\n--- Verification ---")
try:
print(f"Loading tokenizer for verification from: {OUTPUT_DIR}")
loaded_hf_tokenizer = PreTrainedTokenizerFast.from_pretrained(OUTPUT_DIR)
# Test multiple cases, especially those starting with periods or spaces
test_cases = [
"Simple sentence.",
" Sentence starting with space.",
"Sentence. Another sentence.",
". Sentence starting with period.",
"Word.Word",
"The quick brown fox jumps over the lazy dog."
]
print("\n=== Testing with new tokenizer ===")
for i, text in enumerate(test_cases):
print(f"\nTest {i+1}: '{text}'")
tokens = loaded_hf_tokenizer.tokenize(text)
print(f"Tokens: {tokens}")
encoded = loaded_hf_tokenizer.encode(text, add_special_tokens=True)
decoded = loaded_hf_tokenizer.decode(encoded, skip_special_tokens=True)
print(f"Encoded: {encoded}")
print(f"Decoded: '{decoded}'")
# Check if tokenization properly preserves content
if text.strip() == decoded.strip():
print("✓ Encoding/decoding preserved text content")
else:
print(f"⚠ Warning: Text content changed during encoding/decoding")
print(f" Original: '{text}'")
print(f" Decoded: '{decoded}'")
# Check first token distributions
print("\n=== First Position Token Analysis ===")
print("Analyzing first token after <s> for potential bias...")
# Simplified analysis of first token (just for demonstration)
from collections import Counter
first_token_counter = Counter()
with open(TRAIN_FILES[0], 'r', encoding='utf-8') as f:
for i, line in enumerate(f):
if i >= 100: # Just check first 100 lines
break
line = line.strip()
if not line:
continue
encoded = loaded_hf_tokenizer.encode(line, add_special_tokens=True)
if len(encoded) > 1: # Make sure there's at least one token after <s>
first_token_id = encoded[1]
first_token_counter[first_token_id] += 1
total = sum(first_token_counter.values())
if total > 0:
print(f"\nTop 5 tokens at first position (after <s>) from {total} samples:")
for token_id, count in first_token_counter.most_common(5):
token_text = loaded_hf_tokenizer.decode([token_id])
percentage = (count / total) * 100
print(f"Token: '{token_text}' (ID: {token_id}) | Count: {count} | {percentage:.2f}%")
# Specifically check period token
period_id = loaded_hf_tokenizer.encode('.', add_special_tokens=False)[0]
period_count = first_token_counter.get(period_id, 0)
period_percentage = (period_count / total) * 100 if total > 0 else 0
print(f"\nPeriod token ('.', ID: {period_id}) at first position: {period_count} times ({period_percentage:.2f}%)")
except Exception as e:
print(f"Error during verification: {e}")
print("\n--- Tokenizer training script finished ---") |