File size: 1,300 Bytes
			
			| e6b57de 5159d00 37293dc e6b57de 5159d00 2bc1a5b 5159d00 2bc1a5b 5159d00 e6b57de 5159d00 2bc1a5b 5159d00 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 | """Module for tokenization utilities"""
import logging
from termcolor import colored
def check_dataset_labels(dataset, tokenizer, num_examples=5):
    """Log a colour-coded view of the leading examples of a tokenized dataset.

    Each inspected example is handed to ``check_example_labels``.

    Args:
        dataset: indexable, sized collection of tokenized examples. Each example
            is read via ``dataset[idx]``; presumably a mapping with
            ``input_ids``, ``labels`` and ``attention_mask`` keys — confirm
            against callers.
        tokenizer: object whose ``decode`` method turns a token id into text.
        num_examples: maximum number of leading examples to inspect (default 5).
            A dataset shorter than this is inspected in full instead of
            raising ``IndexError``.
    """
    # The dataset is expected to be shuffled already, so the leading examples
    # act as a sample; clamp to the dataset length so short datasets work.
    for idx in range(min(num_examples, len(dataset))):
        check_example_labels(dataset[idx], tokenizer)
def check_example_labels(example, tokenizer):
    """Log every token of one tokenized example, coloured by its label.

    Colour key: red = label ``-100`` (the ignore value), yellow = label ``0``,
    green = any other label. Each decoded token is followed, in white, by
    ``(label_id, mask, input_id)``.

    Args:
        example: mapping with ``input_ids``, ``labels`` and ``attention_mask``
            sequences; they are walked in lockstep (``zip`` stops at the
            shortest one).
        tokenizer: object whose ``decode`` method turns a single token id into
            text.

    Returns:
        The space-joined coloured string that was logged (``""`` when the
        example has no tokens).
    """
    input_ids = example["input_ids"]
    labels = example["labels"]
    attention_mask = example["attention_mask"]

    # Compare input_ids and labels element-wise; positions with the ignore
    # label (-100) or attention_mask == 0 are usually the ones to scrutinize.
    colored_tokens = []
    for input_id, label_id, mask in zip(input_ids, labels, attention_mask):
        decoded_input_token = tokenizer.decode(input_id)
        # Choose the colour based on whether the label has the ignore value.
        color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
        colored_token = colored(decoded_input_token, color) + colored(
            f"({label_id}, {mask}, {input_id})", "white"
        )
        colored_tokens.append(colored_token)

    logged = " ".join(colored_tokens)
    logging.info(logged)
    logging.info("\n\n\n")
    return logged
 |