Tom Jobbins
Debug tokenization output: Add ability to output text only (no tokens), and/or specify num samples to see (#511)
48434be
unverified
| """Module for tokenization utilities""" | |
| import logging | |
| from termcolor import colored | |
| LOG = logging.getLogger("axolotl") | |
def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
    """Log a colorized token/label view of the first few examples of ``dataset``.

    Args:
        dataset: Indexable, sized collection of tokenized examples; each
            inspected element is passed to ``check_example_labels``.
        tokenizer: Tokenizer forwarded to ``check_example_labels`` for decoding.
        num_examples: Maximum number of leading examples to inspect. Clamped
            to ``len(dataset)`` so a small dataset does not raise ``IndexError``.
        text_only: If True, show only decoded text, without the per-token
            ``(label, mask, input_id)`` annotation.
    """
    # The dataset is expected to be shuffled already, so the leading examples
    # act as a sample. Clamp the count so we never index past the end.
    for idx in range(min(num_examples, len(dataset))):
        check_example_labels(dataset[idx], tokenizer, text_only=text_only)
def check_example_labels(example, tokenizer, text_only=False):
    """Log and return a colorized, token-by-token rendering of one example.

    Each input id is decoded on its own and colored by its label: red when
    the label is ``-100`` (the ignore value), yellow when it is ``0``, and
    green otherwise. Unless ``text_only`` is set, each token is followed by a
    white ``(label_id, mask, input_id)`` annotation.

    Args:
        example: Mapping with ``"input_ids"``, ``"labels"`` and
            ``"attention_mask"`` sequences, compared position by position
            (iteration stops at the shortest of the three).
        tokenizer: Object whose ``decode`` method converts a single input id
            to text.
        text_only: If True, omit the per-token annotation.

    Returns:
        The space-joined colorized string that was logged.
    """
    input_ids = example["input_ids"]
    labels = example["labels"]
    attention_mask = example["attention_mask"]

    colored_tokens = []
    for input_id, label_id, mask in zip(input_ids, labels, attention_mask):
        decoded_input_token = tokenizer.decode(input_id)
        # -100 marks positions ignored by the loss; 0 is highlighted separately.
        if label_id == -100:
            color = "red"
        elif label_id == 0:
            color = "yellow"
        else:
            color = "green"
        colored_token = colored(decoded_input_token, color)
        if not text_only:
            colored_token += colored(f"({label_id}, {mask}, {input_id})", "white")
        colored_tokens.append(colored_token)

    rendered = " ".join(colored_tokens)
    LOG.info(rendered)
    LOG.info("\n\n\n")
    return rendered