Spaces:
Runtime error
Runtime error
| """ | |
| AttackArgs Class | |
| ================ | |
| """ | |
| from dataclasses import dataclass, field | |
| import json | |
| import os | |
| import sys | |
| import time | |
| from typing import Dict, Optional | |
| import textattack | |
| from textattack.shared.utils import ARGS_SPLIT_TOKEN, load_module_from_file | |
| from .attack import Attack | |
| from .dataset_args import DatasetArgs | |
| from .model_args import ModelArgs | |
# Registries mapping command-line names to fully-qualified class paths.
# Each value is later turned into a constructor call with ``eval`` (see the
# ``_create_*_from_args`` helpers), so it must be reachable as an attribute
# chain starting at the ``textattack`` package.  Insertion order is preserved
# because the keys are shown to users in help text and argparse choices.
ATTACK_RECIPE_NAMES = {
    name: f"textattack.attack_recipes.{recipe}"
    for name, recipe in (
        ("alzantot", "GeneticAlgorithmAlzantot2018"),
        ("bae", "BAEGarg2019"),
        ("bert-attack", "BERTAttackLi2020"),
        ("faster-alzantot", "FasterGeneticAlgorithmJia2019"),
        ("deepwordbug", "DeepWordBugGao2018"),
        ("hotflip", "HotFlipEbrahimi2017"),
        ("input-reduction", "InputReductionFeng2018"),
        ("kuleshov", "Kuleshov2017"),
        ("morpheus", "MorpheusTan2020"),
        ("seq2sick", "Seq2SickCheng2018BlackBox"),
        ("textbugger", "TextBuggerLi2018"),
        ("textfooler", "TextFoolerJin2019"),
        ("pwws", "PWWSRen2019"),
        ("iga", "IGAWang2019"),
        ("pruthi", "Pruthi2019"),
        ("pso", "PSOZang2020"),
        ("checklist", "CheckList2020"),
        ("clare", "CLARE2020"),
        ("a2t", "A2TYoo2021"),
    )
}

# Transformations that only need the input text.
BLACK_BOX_TRANSFORMATION_CLASS_NAMES = {
    name: f"textattack.transformations.{class_name}"
    for name, class_name in (
        ("random-synonym-insertion", "RandomSynonymInsertion"),
        ("word-deletion", "WordDeletion"),
        ("word-swap-embedding", "WordSwapEmbedding"),
        ("word-swap-homoglyph", "WordSwapHomoglyphSwap"),
        ("word-swap-inflections", "WordSwapInflections"),
        ("word-swap-neighboring-char-swap", "WordSwapNeighboringCharacterSwap"),
        ("word-swap-random-char-deletion", "WordSwapRandomCharacterDeletion"),
        ("word-swap-random-char-insertion", "WordSwapRandomCharacterInsertion"),
        ("word-swap-random-char-substitution", "WordSwapRandomCharacterSubstitution"),
        ("word-swap-wordnet", "WordSwapWordNet"),
        ("word-swap-masked-lm", "WordSwapMaskedLM"),
        ("word-swap-hownet", "WordSwapHowNet"),
        ("word-swap-qwerty", "WordSwapQWERTY"),
    )
}

# Transformations constructed with the victim model as first argument.
WHITE_BOX_TRANSFORMATION_CLASS_NAMES = {
    name: f"textattack.transformations.{class_name}"
    for name, class_name in (("word-swap-gradient", "WordSwapGradientBased"),)
}

CONSTRAINT_CLASS_NAMES = {
    name: f"textattack.constraints.{path}"
    for name, path in (
        # Semantics constraints
        ("embedding", "semantics.WordEmbeddingDistance"),
        ("bert", "semantics.sentence_encoders.BERT"),
        ("infer-sent", "semantics.sentence_encoders.InferSent"),
        ("thought-vector", "semantics.sentence_encoders.ThoughtVector"),
        ("use", "semantics.sentence_encoders.UniversalSentenceEncoder"),
        ("muse", "semantics.sentence_encoders.MultilingualUniversalSentenceEncoder"),
        ("bert-score", "semantics.BERTScore"),
        # Grammaticality constraints
        ("lang-tool", "grammaticality.LanguageTool"),
        ("part-of-speech", "grammaticality.PartOfSpeech"),
        ("goog-lm", "grammaticality.language_models.GoogleLanguageModel"),
        ("gpt2", "grammaticality.language_models.GPT2"),
        ("learning-to-write", "grammaticality.language_models.LearningToWriteLanguageModel"),
        ("cola", "grammaticality.COLA"),
        # Overlap constraints
        ("bleu", "overlap.BLEU"),
        ("chrf", "overlap.chrF"),
        ("edit-distance", "overlap.LevenshteinEditDistance"),
        ("meteor", "overlap.METEOR"),
        ("max-words-perturbed", "overlap.MaxWordsPerturbed"),
        # Pre-transformation constraints
        ("repeat", "pre_transformation.RepeatModification"),
        ("stopword", "pre_transformation.StopwordModification"),
        ("max-word-index", "pre_transformation.MaxWordIndexModification"),
    )
}

SEARCH_METHOD_CLASS_NAMES = {
    name: f"textattack.search_methods.{class_name}"
    for name, class_name in (
        ("beam-search", "BeamSearch"),
        ("greedy", "GreedySearch"),
        ("ga-word", "GeneticAlgorithm"),
        ("greedy-word-wir", "GreedyWordSwapWIR"),
        ("pso", "ParticleSwarmOptimization"),
    )
}

GOAL_FUNCTION_CLASS_NAMES = {
    name: f"textattack.goal_functions.{path}"
    for name, path in (
        # Classification goal functions
        ("targeted-classification", "classification.TargetedClassification"),
        ("untargeted-classification", "classification.UntargetedClassification"),
        ("input-reduction", "classification.InputReduction"),
        # Text goal functions
        ("minimize-bleu", "text.MinimizeBleu"),
        ("non-overlapping-output", "text.NonOverlappingOutput"),
        ("text-to-text", "text.TextToTextGoalFunction"),
    )
}
@dataclass
class AttackArgs:
    """Attack arguments to be passed to :class:`~textattack.Attacker`.

    Args:
        num_examples (:obj:`int`, `optional`, defaults to :obj:`10`):
            The number of examples to attack. :obj:`-1` for entire dataset.
        num_successful_examples (:obj:`int`, `optional`, defaults to :obj:`None`):
            The number of successful adversarial examples we want. This is different from :obj:`num_examples`
            as :obj:`num_examples` only cares about attacking `N` samples while :obj:`num_successful_examples` aims to keep attacking
            until we have `N` successful cases.

            .. note::
                If set, this argument overrides `num_examples` argument (which is reset to :obj:`None`).
        num_examples_offset (:obj:`int`, `optional`, defaults to :obj:`0`):
            The offset index to start at in the dataset.
        attack_n (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to run attack until total of `N` examples have been attacked (and not skipped).
        shuffle (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If :obj:`True`, we randomly shuffle the dataset before attacking. However, this avoids actually shuffling
            the dataset internally and opts for shuffling the list of indices of examples we want to attack. This means
            :obj:`shuffle` can now be used with checkpoint saving.
        query_budget (:obj:`int`, `optional`, defaults to :obj:`None`):
            The maximum number of model queries allowed per example attacked.
            If not set, we use the query budget set in the :class:`~textattack.goal_functions.GoalFunction` object (which by default is :obj:`float("inf")`).

            .. note::
                Setting this overwrites the query budget set in :class:`~textattack.goal_functions.GoalFunction` object.
        checkpoint_interval (:obj:`int`, `optional`, defaults to :obj:`None`):
            If set, checkpoint will be saved after attacking every `N` examples. If :obj:`None` is passed, no checkpoints will be saved.
        checkpoint_dir (:obj:`str`, `optional`, defaults to :obj:`"checkpoints"`):
            The directory to save checkpoint files.
        random_seed (:obj:`int`, `optional`, defaults to :obj:`765`):
            Random seed for reproducibility.
        parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If :obj:`True`, run attack using multiple CPUs/GPUs.
        num_workers_per_device (:obj:`int`, `optional`, defaults to :obj:`1`):
            Number of worker processes to run per device in parallel mode (i.e. :obj:`parallel=True`). For example, if you are using GPUs and :obj:`num_workers_per_device=2`,
            then 2 processes will be running in each GPU.
        log_to_txt (:obj:`str`, `optional`, defaults to :obj:`None`):
            If set, save attack logs as a `.txt` file to the directory specified by this argument.
            If the last part of the provided path ends with `.txt` extension, it is assumed to be the desired path of the log file.
        log_to_csv (:obj:`str`, `optional`, defaults to :obj:`None`):
            If set, save attack logs as a CSV file to the directory specified by this argument.
            If the last part of the provided path ends with `.csv` extension, it is assumed to be the desired path of the log file.
        log_summary_to_json (:obj:`str`, `optional`, defaults to :obj:`None`):
            If set, save the attack results summary as a JSON file to the directory specified by this argument.
            If the last part of the provided path ends with `.json` extension, it is assumed to be the desired path of the summary file.
        csv_coloring_style (:obj:`str`, `optional`, defaults to :obj:`"file"`):
            Method for choosing how to mark perturbed parts of the text. Options are :obj:`"file"`, :obj:`"plain"`, and :obj:`"html"`.
            :obj:`"file"` wraps perturbed parts with double brackets :obj:`[[ <text> ]]` while :obj:`"plain"` does not mark the text in any way.
        log_to_visdom (:obj:`dict`, `optional`, defaults to :obj:`None`):
            If set, Visdom logger is used with the provided dictionary passed as a keyword arguments to :class:`~textattack.loggers.VisdomLogger`.
            Pass in empty dictionary to use default arguments. For custom logger, the dictionary should have the following
            three keys and their corresponding values: :obj:`"env", "port", "hostname"`.
        log_to_wandb (:obj:`dict`, `optional`, defaults to :obj:`None`):
            If set, WandB logger is used with the provided dictionary passed as a keyword arguments to :class:`~textattack.loggers.WeightsAndBiasesLogger`.
            Pass in empty dictionary to use default arguments. For custom logger, the dictionary should have the following
            key and its corresponding value: :obj:`"project"`.
        disable_stdout (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Disable displaying individual attack results to stdout.
        silent (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Disable all logging (except for errors). This is stronger than :obj:`disable_stdout`.
        enable_advance_metrics (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Enable calculation and display of optional advance post-hoc metrics like perplexity, grammar errors, etc.
        metrics (:obj:`dict`, `optional`, defaults to :obj:`None`):
            Passed as-is to :class:`~textattack.loggers.AttackLogManager` when loggers are created.
    """

    num_examples: Optional[int] = 10
    num_successful_examples: Optional[int] = None
    num_examples_offset: int = 0
    attack_n: bool = False
    shuffle: bool = False
    query_budget: Optional[int] = None
    checkpoint_interval: Optional[int] = None
    checkpoint_dir: str = "checkpoints"
    random_seed: int = 765  # equivalent to sum((ord(c) for c in "TEXTATTACK"))
    parallel: bool = False
    num_workers_per_device: int = 1
    log_to_txt: Optional[str] = None
    log_to_csv: Optional[str] = None
    log_summary_to_json: Optional[str] = None
    csv_coloring_style: str = "file"
    log_to_visdom: Optional[dict] = None
    log_to_wandb: Optional[dict] = None
    disable_stdout: bool = False
    silent: bool = False
    enable_advance_metrics: bool = False
    metrics: Optional[Dict] = None

    def __post_init__(self):
        """Validate field values after dataclass initialization.

        A truthy ``num_successful_examples`` resets ``num_examples`` to ``None``.

        Raises:
            AssertionError: If any numeric field is out of its allowed range.
        """
        if self.num_successful_examples:
            self.num_examples = None
        if self.num_examples:
            assert (
                self.num_examples >= 0 or self.num_examples == -1
            ), "`num_examples` must be greater than or equal to 0 or equal to -1."
        if self.num_successful_examples:
            assert (
                self.num_successful_examples >= 0
            ), "`num_successful_examples` must be greater than or equal to 0."
        if self.query_budget:
            assert self.query_budget > 0, "`query_budget` must be greater than 0."
        if self.checkpoint_interval:
            assert (
                self.checkpoint_interval > 0
            ), "`checkpoint_interval` must be greater than 0."
        assert (
            self.num_workers_per_device > 0
        ), "`num_workers_per_device` must be greater than 0."

    @classmethod
    def _add_parser_args(cls, parser):
        """Add listed args to command line parser and return the parser.

        Defaults are taken from a default-constructed instance so the CLI and
        the dataclass stay in sync.
        """
        default_obj = cls()
        # Only one of the two "how many examples" options may be given.
        num_ex_group = parser.add_mutually_exclusive_group(required=False)
        num_ex_group.add_argument(
            "--num-examples",
            "-n",
            type=int,
            default=default_obj.num_examples,
            help="The number of examples to process, -1 for entire dataset.",
        )
        num_ex_group.add_argument(
            "--num-successful-examples",
            type=int,
            default=default_obj.num_successful_examples,
            help="The number of successful adversarial examples we want.",
        )
        parser.add_argument(
            "--num-examples-offset",
            "-o",
            type=int,
            required=False,
            default=default_obj.num_examples_offset,
            help="The offset to start at in the dataset.",
        )
        parser.add_argument(
            "--query-budget",
            "-q",
            type=int,
            default=default_obj.query_budget,
            help="The maximum number of model queries allowed per example attacked. Setting this overwrites the query budget set in `GoalFunction` object.",
        )
        parser.add_argument(
            "--shuffle",
            action="store_true",
            default=default_obj.shuffle,
            help="If `True`, shuffle the samples before we attack the dataset. Default is False.",
        )
        parser.add_argument(
            "--attack-n",
            action="store_true",
            default=default_obj.attack_n,
            help="Whether to run attack until `n` examples have been attacked (not skipped).",
        )
        parser.add_argument(
            "--checkpoint-dir",
            required=False,
            type=str,
            default=default_obj.checkpoint_dir,
            help="The directory to save checkpoint files.",
        )
        parser.add_argument(
            "--checkpoint-interval",
            required=False,
            type=int,
            default=default_obj.checkpoint_interval,
            help="If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
        )
        parser.add_argument(
            "--random-seed",
            default=default_obj.random_seed,
            type=int,
            help="Random seed for reproducibility.",
        )
        parser.add_argument(
            "--parallel",
            action="store_true",
            default=default_obj.parallel,
            help="Run attack using multiple GPUs.",
        )
        parser.add_argument(
            "--num-workers-per-device",
            default=default_obj.num_workers_per_device,
            type=int,
            help="Number of worker processes to run per device.",
        )
        # The --log-to-* options accept an optional value: given with no value
        # they produce "" (current directory, auto-named file).
        parser.add_argument(
            "--log-to-txt",
            nargs="?",
            default=default_obj.log_to_txt,
            const="",
            type=str,
            help="Path to which to save attack logs as a text file. Set this argument if you want to save text logs. "
            "If the last part of the path ends with `.txt` extension, the path is assumed to path for output file.",
        )
        parser.add_argument(
            "--log-to-csv",
            nargs="?",
            default=default_obj.log_to_csv,
            const="",
            type=str,
            help="Path to which to save attack logs as a CSV file. Set this argument if you want to save CSV logs. "
            "If the last part of the path ends with `.csv` extension, the path is assumed to path for output file.",
        )
        parser.add_argument(
            "--log-summary-to-json",
            nargs="?",
            default=default_obj.log_summary_to_json,
            const="",
            type=str,
            help="Path to which to save attack summary as a JSON file. Set this argument if you want to save attack results summary in a JSON. "
            "If the last part of the path ends with `.json` extension, the path is assumed to path for output file.",
        )
        parser.add_argument(
            "--csv-coloring-style",
            default=default_obj.csv_coloring_style,
            type=str,
            help='Method for choosing how to mark perturbed parts of the text in CSV logs. Options are "file", "plain", and "html". '
            '"file" wraps text with double brackets `[[ <text> ]]` while "plain" does not mark any text. Default is "file".',
        )
        # JSON-valued options; a bare flag uses the ``const`` JSON string,
        # which argparse also passes through ``type=json.loads``.
        parser.add_argument(
            "--log-to-visdom",
            nargs="?",
            default=None,
            const='{"env": "main", "port": 8097, "hostname": "localhost"}',
            type=json.loads,
            help="Set this argument if you want to log attacks to Visdom. The dictionary should have the following "
            'three keys and their corresponding values: `"env", "port", "hostname"`. '
            'Example for command line use: `--log-to-visdom {"env": "main", "port": 8097, "hostname": "localhost"}`.',
        )
        parser.add_argument(
            "--log-to-wandb",
            nargs="?",
            default=None,
            const='{"project": "textattack"}',
            type=json.loads,
            help="Set this argument if you want to log attacks to WandB. The dictionary should have the following "
            'key and its corresponding value: `"project"`. '
            'Example for command line use: `--log-to-wandb {"project": "textattack"}`.',
        )
        parser.add_argument(
            "--disable-stdout",
            action="store_true",
            default=default_obj.disable_stdout,
            help="Disable logging attack results to stdout",
        )
        parser.add_argument(
            "--silent",
            action="store_true",
            default=default_obj.silent,
            help="Disable all logging",
        )
        parser.add_argument(
            "--enable-advance-metrics",
            action="store_true",
            default=default_obj.enable_advance_metrics,
            help="Enable calculation and display of optional advance post-hoc metrics like perplexity, USE distance, etc.",
        )
        return parser

    @staticmethod
    def _prepare_log_path(path, extension, default_filename):
        """Resolve a ``--log-*`` argument to a file path and ensure its directory exists.

        If ``path`` ends with ``extension`` (case-insensitive) it is used
        verbatim; otherwise it is treated as a directory and
        ``default_filename`` is joined onto it.
        """
        if path.lower().endswith(extension):
            file_path = path
        else:
            file_path = os.path.join(path, default_filename)
        dir_path = os.path.dirname(file_path) or "."
        if not os.path.exists(dir_path):
            # exist_ok guards against another process creating it concurrently.
            os.makedirs(dir_path, exist_ok=True)
        return file_path

    @classmethod
    def create_loggers_from_args(cls, args):
        """Creates AttackLogManager from an AttackArgs object.

        Args:
            args: An instance of ``cls`` describing which outputs to enable.

        Returns:
            The configured :class:`~textattack.loggers.AttackLogManager`.

        Raises:
            AssertionError: If ``args`` is not an instance of ``cls``.
        """
        assert isinstance(
            args, cls
        ), f"Expect args to be of type `{cls.__name__}`, but got type `{type(args)}`."

        attack_log_manager = textattack.loggers.AttackLogManager(args.metrics)
        # One timestamp shared by every auto-named file created in this call.
        timestamp = time.strftime("%Y-%m-%d-%H-%M")

        if args.log_to_txt is not None:
            txt_file_path = cls._prepare_log_path(
                args.log_to_txt, ".txt", f"{timestamp}-log.txt"
            )
            # Text logs always use the "file" marking style.
            attack_log_manager.add_output_file(txt_file_path, "file")

        if args.log_to_csv is not None:
            csv_file_path = cls._prepare_log_path(
                args.log_to_csv, ".csv", f"{timestamp}-log.csv"
            )
            # "plain" is passed down as ``None`` (no marking).
            color_method = (
                None if args.csv_coloring_style == "plain" else args.csv_coloring_style
            )
            attack_log_manager.add_output_csv(csv_file_path, color_method)

        if args.log_summary_to_json is not None:
            summary_json_file_path = cls._prepare_log_path(
                args.log_summary_to_json,
                ".json",
                f"{timestamp}-attack_summary_log.json",
            )
            attack_log_manager.add_output_summary_json(summary_json_file_path)

        if args.log_to_visdom is not None:
            attack_log_manager.enable_visdom(**args.log_to_visdom)

        if args.log_to_wandb is not None:
            attack_log_manager.enable_wandb(**args.log_to_wandb)

        if not args.disable_stdout:
            if sys.stdout.isatty():
                attack_log_manager.enable_stdout()
            else:
                # NOTE(review): for non-TTY stdout this calls disable_color()
                # instead of enable_stdout(); presumably disable_color() still
                # logs to stdout, just without color codes — confirm.
                attack_log_manager.disable_color()
        return attack_log_manager
@dataclass
class _CommandLineAttackArgs:
    """Attack args for command line execution. This requires more arguments to
    create ``Attack`` object as specified.

    Transformation, constraint, goal-function, search-method and recipe names
    may carry constructor keyword arguments as ``<name>^<kwargs>`` (split on
    ``ARGS_SPLIT_TOKEN``); the kwargs text is inserted verbatim into the
    constructor call.

    Args:
        transformation (:obj:`str`, `optional`, defaults to :obj:`"word-swap-embedding"`):
            Name of transformation to use.
        constraints (:obj:`list[str]`, `optional`, defaults to :obj:`["repeat", "stopword"]`):
            List of names of constraints to use.
        goal_function (:obj:`str`, `optional`, defaults to :obj:`"untargeted-classification"`):
            Name of goal function to use.
        search_method (:obj:`str`, `optional`, defaults to :obj:`"greedy-word-wir"`):
            Name of search method to use.
        attack_recipe (:obj:`str`, `optional`, defaults to :obj:`None`):
            Name of attack recipe to use.

            .. note::
                Setting this overrides any previous selection of transformation, constraints, goal function, and search method.
        attack_from_file (:obj:`str`, `optional`, defaults to :obj:`None`):
            Path of `.py` file from which to load attack from. Use `<path>^<variable_name>` to specify which variable to import from the file
            (defaults to ``attack``). The variable is called with the model wrapper.

            .. note::
                If this is set, it overrides any previous selection of transformation, constraints, goal function, and search method.
        interactive (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If `True`, carry attack in interactive mode.
        parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If `True`, attack in parallel.
        model_batch_size (:obj:`int`, `optional`, defaults to :obj:`32`):
            The batch size for making queries to the victim model.
        model_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**18`):
            The maximum number of items to keep in the model results cache at once.
        constraint_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**18`):
            The maximum number of items to keep in the constraints cache at once.
    """

    transformation: str = "word-swap-embedding"
    # default_factory so each instance gets its own list.
    constraints: list = field(default_factory=lambda: ["repeat", "stopword"])
    goal_function: str = "untargeted-classification"
    search_method: str = "greedy-word-wir"
    attack_recipe: Optional[str] = None
    attack_from_file: Optional[str] = None
    interactive: bool = False
    parallel: bool = False
    model_batch_size: int = 32
    model_cache_size: int = 2**18
    constraint_cache_size: int = 2**18

    @staticmethod
    def _split_spec(spec):
        """Split ``"<name>^<params>"`` into ``(name, params)``.

        ``params`` is ``None`` when ``spec`` contains no ``ARGS_SPLIT_TOKEN``.
        Only the first token splits, so ``params`` is everything after it.
        """
        if ARGS_SPLIT_TOKEN not in spec:
            return spec, None
        name, params = spec.split(ARGS_SPLIT_TOKEN, 1)
        return name, params

    @classmethod
    def _add_parser_args(cls, parser):
        """Add listed args to command line parser and return the parser."""
        default_obj = cls()
        transformation_names = sorted(
            set(BLACK_BOX_TRANSFORMATION_CLASS_NAMES)
            | set(WHITE_BOX_TRANSFORMATION_CLASS_NAMES)
        )
        parser.add_argument(
            "--transformation",
            type=str,
            required=False,
            default=default_obj.transformation,
            help='The transformation to apply. Usage: "--transformation {transformation}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
            + ", ".join(transformation_names),
        )
        parser.add_argument(
            "--constraints",
            type=str,
            required=False,
            nargs="*",
            default=default_obj.constraints,
            help='Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
            + ", ".join(CONSTRAINT_CLASS_NAMES),
        )
        goal_function_choices = ", ".join(GOAL_FUNCTION_CLASS_NAMES.keys())
        parser.add_argument(
            "--goal-function",
            "-g",
            default=default_obj.goal_function,
            help=f"The goal function to use. choices: {goal_function_choices}",
        )
        # Search method, recipe and attack file are alternative ways to
        # specify the attack, so only one may be given.
        attack_group = parser.add_mutually_exclusive_group(required=False)
        search_choices = ", ".join(SEARCH_METHOD_CLASS_NAMES.keys())
        attack_group.add_argument(
            "--search-method",
            "--search",
            "-s",
            type=str,
            required=False,
            default=default_obj.search_method,
            help=f"The search method to use. choices: {search_choices}",
        )
        # NOTE(review): `choices` restricts this option to bare recipe names, so
        # the "<recipe>^<kwargs>" form handled in `_create_attack_from_args` is
        # rejected by argparse when coming from the command line.
        attack_group.add_argument(
            "--attack-recipe",
            "--recipe",
            "-r",
            type=str,
            required=False,
            default=default_obj.attack_recipe,
            help="full attack recipe (overrides provided goal function, transformation & constraints)",
            choices=ATTACK_RECIPE_NAMES.keys(),
        )
        attack_group.add_argument(
            "--attack-from-file",
            type=str,
            required=False,
            default=default_obj.attack_from_file,
            help="Path of `.py` file from which to load attack from. Use `<path>^<variable_name>` to specifiy which variable to import from the file.",
        )
        parser.add_argument(
            "--interactive",
            action="store_true",
            default=default_obj.interactive,
            help="Whether to run attacks interactively.",
        )
        parser.add_argument(
            "--model-batch-size",
            type=int,
            default=default_obj.model_batch_size,
            help="The batch size for making calls to the model.",
        )
        parser.add_argument(
            "--model-cache-size",
            type=int,
            default=default_obj.model_cache_size,
            help="The maximum number of items to keep in the model results cache at once.",
        )
        parser.add_argument(
            "--constraint-cache-size",
            type=int,
            default=default_obj.constraint_cache_size,
            help="The maximum number of items to keep in the constraints cache at once.",
        )
        return parser

    # NOTE: the ``_create_*`` helpers below build constructor calls as strings
    # and ``eval`` them. The text comes from command-line arguments, so only
    # use these with trusted input.

    @classmethod
    def _create_transformation_from_args(cls, args, model_wrapper):
        """Create `Transformation` based on provided `args` and `model_wrapper`.

        Raises:
            ValueError: If the transformation name is not registered.
        """
        transformation_name, params = cls._split_spec(args.transformation)
        if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
            class_path = WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]
            # White-box transformations receive the raw model first.
            call_args = (
                "model_wrapper.model"
                if params is None
                else f"model_wrapper.model, {params}"
            )
        elif transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
            class_path = BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]
            call_args = "" if params is None else params
        else:
            raise ValueError(
                f"Error: unsupported transformation {transformation_name}"
            )
        # ``model_wrapper`` must be a local here: the eval'd text refers to it.
        return eval(f"{class_path}({call_args})")

    @classmethod
    def _create_goal_function_from_args(cls, args, model_wrapper):
        """Create `GoalFunction` based on provided `args` and `model_wrapper`.

        The query budget (if set), model cache size and batch size from
        ``args`` are applied to the constructed goal function.

        Raises:
            ValueError: If the goal function name is not registered.
        """
        goal_function_name, params = cls._split_spec(args.goal_function)
        if goal_function_name not in GOAL_FUNCTION_CLASS_NAMES:
            raise ValueError(f"Error: unsupported goal_function {goal_function_name}")
        call_args = "model_wrapper" if params is None else f"model_wrapper, {params}"
        goal_function = eval(
            f"{GOAL_FUNCTION_CLASS_NAMES[goal_function_name]}({call_args})"
        )
        if args.query_budget:
            goal_function.query_budget = args.query_budget
        goal_function.model_cache_size = args.model_cache_size
        goal_function.batch_size = args.model_batch_size
        return goal_function

    @classmethod
    def _create_constraints_from_args(cls, args):
        """Create list of `Constraints` based on provided `args`.

        Returns an empty list when ``args.constraints`` is empty or ``None``.

        Raises:
            ValueError: If any constraint name is not registered.
        """
        if not args.constraints:
            return []
        constraints = []
        for spec in args.constraints:
            constraint_name, params = cls._split_spec(spec)
            if constraint_name not in CONSTRAINT_CLASS_NAMES:
                raise ValueError(f"Error: unsupported constraint {constraint_name}")
            call_args = "" if params is None else params
            constraints.append(
                eval(f"{CONSTRAINT_CLASS_NAMES[constraint_name]}({call_args})")
            )
        return constraints

    @classmethod
    def _create_attack_from_args(cls, args, model_wrapper):
        """Given ``CommandLineArgs`` and ``ModelWrapper``, return specified
        ``Attack`` object.

        Precedence: ``attack_recipe``, then ``attack_from_file``, then the
        individual goal function / transformation / constraints / search method.

        Raises:
            AssertionError: If ``args`` is not an instance of ``cls``.
            ValueError: If a named component is not registered, or the attack
                file lacks the requested variable.
        """
        assert isinstance(
            args, cls
        ), f"Expect args to be of type `{cls.__name__}`, but got type `{type(args)}`."

        if args.attack_recipe:
            recipe_name, params = cls._split_spec(args.attack_recipe)
            if recipe_name not in ATTACK_RECIPE_NAMES:
                raise ValueError(f"Error: unsupported recipe {recipe_name}")
            call_args = "model_wrapper" if params is None else f"model_wrapper, {params}"
            recipe = eval(f"{ATTACK_RECIPE_NAMES[recipe_name]}.build({call_args})")
            if args.query_budget:
                recipe.goal_function.query_budget = args.query_budget
            recipe.goal_function.model_cache_size = args.model_cache_size
            # Honour --model-batch-size for recipes too, as the non-recipe
            # path does in `_create_goal_function_from_args`.
            recipe.goal_function.batch_size = args.model_batch_size
            recipe.constraint_cache_size = args.constraint_cache_size
            return recipe

        if args.attack_from_file:
            attack_file, attack_name = cls._split_spec(args.attack_from_file)
            if attack_name is None:
                attack_name = "attack"
            attack_module = load_module_from_file(attack_file)
            if not hasattr(attack_module, attack_name):
                raise ValueError(
                    f"Loaded `{attack_file}` but could not find `{attack_name}`."
                )
            attack_func = getattr(attack_module, attack_name)
            return attack_func(model_wrapper)

        goal_function = cls._create_goal_function_from_args(args, model_wrapper)
        transformation = cls._create_transformation_from_args(args, model_wrapper)
        constraints = cls._create_constraints_from_args(args)

        search_name, params = cls._split_spec(args.search_method)
        if search_name not in SEARCH_METHOD_CLASS_NAMES:
            raise ValueError(f"Error: unsupported search method {search_name}")
        call_args = "" if params is None else params
        search_method = eval(f"{SEARCH_METHOD_CLASS_NAMES[search_name]}({call_args})")

        return Attack(
            goal_function,
            constraints,
            transformation,
            search_method,
            constraint_cache_size=args.constraint_cache_size,
        )
# This neat trick allows us to reorder the arguments to avoid TypeErrors commonly found when inheriting dataclass.
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
@dataclass
class CommandLineAttackArgs(AttackArgs, _CommandLineAttackArgs, DatasetArgs, ModelArgs):
    """All argument groups needed to run an attack from the command line.

    ``@dataclass`` gathers fields from every dataclass base in reverse MRO
    order, so fields of ``ModelArgs`` come first and ``AttackArgs`` last in the
    generated ``__init__``. The base order above is chosen for that reason
    (fields without defaults must precede fields with defaults).
    """

    @classmethod
    def _add_parser_args(cls, parser):
        """Add every base group's command-line options to ``parser`` and return it."""
        parser = ModelArgs._add_parser_args(parser)
        parser = DatasetArgs._add_parser_args(parser)
        parser = _CommandLineAttackArgs._add_parser_args(parser)
        parser = AttackArgs._add_parser_args(parser)
        return parser