Spaces:
Runtime error
Runtime error
| """ | |
| PeekDatasetCommand class | |
| ============================== | |
| """ | |
| from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser | |
| import collections | |
| import re | |
| import numpy as np | |
| import textattack | |
| from textattack.commands import TextAttackCommand | |
def _cb(s):
    """Return ``str(s)`` colored blue for terminal display.

    The value is stringified first and then passed to textattack's
    ``color_text`` helper with ``color="blue"`` and ``method="ansi"``.
    """
    text = str(s)
    return textattack.shared.utils.color_text(text, color="blue", method="ansi")
# Module-level alias for the logger object exposed by ``textattack.shared``;
# every status line emitted by this command goes through it.
logger = textattack.shared.logger
class PeekDatasetCommand(TextAttackCommand):
    """The peek dataset module:

    Takes a peek into a dataset in textattack.
    """

    def run(self, args):
        """Load the dataset described by ``args`` and report summary statistics.

        Logs the number of samples, word-count statistics (total, mean, std,
        min, max), whether the data contains no ASCII uppercase letters, the
        first and last samples, and the most common outputs.

        Args:
            args: Parsed command-line namespace. Its attributes are forwarded
                as keyword arguments to ``textattack.DatasetArgs``.
        """
        uppercase_letters = re.compile(r"[A-Z]")

        dataset_args = textattack.DatasetArgs(**vars(args))
        dataset = textattack.DatasetArgs._create_dataset_from_args(dataset_args)

        num_words = []
        attacked_texts = []
        outputs = []
        data_all_lowercased = True
        for inputs, output in dataset:
            at = textattack.shared.AttackedText(inputs)
            # A single ASCII uppercase letter anywhere means the dataset is
            # not lowercased; once found, the regex scan is skipped for the
            # remaining samples.
            if data_all_lowercased and uppercase_letters.search(at.text):
                data_all_lowercased = False
            attacked_texts.append(at)
            num_words.append(len(at.words))
            outputs.append(output)

        logger.info(f"Number of samples: {_cb(len(attacked_texts))}")
        if not attacked_texts:
            # min()/max() on a zero-size array raise ValueError and there is
            # no first/last sample to show, so stop here instead of crashing.
            logger.warning("Dataset is empty; no statistics to report.")
            return

        logger.info("Number of words per input:")
        num_words = np.array(num_words)
        word_stats = (
            ("total:", num_words.sum()),
            ("mean:", f"{num_words.mean():.2f}"),
            ("std:", f"{num_words.std():.2f}"),
            ("min:", num_words.min()),
            ("max:", num_words.max()),
        )
        for label, value in word_stats:
            logger.info(f"\t{label.ljust(8)} {_cb(value)}")

        logger.info(f"Dataset lowercased: {_cb(data_all_lowercased)}")

        logger.info("First sample:")
        print(attacked_texts[0].printable_text(), "\n")
        logger.info("Last sample:")
        print(attacked_texts[-1].printable_text(), "\n")

        distinct_outputs = set(outputs)
        logger.info(f"Found {len(distinct_outputs)} distinct outputs.")
        # List the distinct values only when there are few of them. The
        # threshold applies to the number of distinct outputs, not to the
        # number of samples.
        if len(distinct_outputs) < 20:
            print(sorted(distinct_outputs))

        logger.info("Most common outputs:")
        for key, value in collections.Counter(outputs).most_common(20):
            # Keys are truncated/padded to 5 characters to keep columns aligned.
            print("\t", str(key)[:5].ljust(5), f" ({value})")

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        """Register the ``peek-dataset`` subcommand on ``main_parser``.

        Adds the dataset-related arguments via
        ``textattack.DatasetArgs._add_parser_args`` and sets the parser's
        ``func`` default to a ``PeekDatasetCommand`` instance.

        Args:
            main_parser: Object providing ``add_parser``.
                NOTE(review): annotated as ``ArgumentParser``, but
                ``add_parser`` is what argparse's sub-parsers action offers;
                presumably the result of ``add_subparsers()`` is passed in.
                Confirm against the caller.
        """
        parser = main_parser.add_parser(
            "peek-dataset",
            help="show main statistics about a dataset",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )
        parser = textattack.DatasetArgs._add_parser_args(parser)
        parser.set_defaults(func=PeekDatasetCommand())