Spaces:
Runtime error
Runtime error
"""
A2T (A2T: Attack for Adversarial Training Recipe)
==================================================
"""
from textattack import Attack | |
from textattack.constraints.grammaticality import PartOfSpeech | |
from textattack.constraints.pre_transformation import ( | |
InputColumnModification, | |
MaxModificationRate, | |
RepeatModification, | |
StopwordModification, | |
) | |
from textattack.constraints.semantics import WordEmbeddingDistance | |
from textattack.constraints.semantics.sentence_encoders import BERT | |
from textattack.goal_functions import UntargetedClassification | |
from textattack.search_methods import GreedyWordSwapWIR | |
from textattack.transformations import WordSwapEmbedding, WordSwapMaskedLM | |
from .attack_recipe import AttackRecipe | |
class A2TYoo2021(AttackRecipe):
    """Towards Improving Adversarial Training of NLP Models.

    (Yoo et al., 2021)

    https://arxiv.org/abs/2109.00544
    """

    @staticmethod
    def build(model_wrapper, mlm=False):
        """Build attack recipe.

        Args:
            model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
                Model wrapper containing both the model and the tokenizer.
            mlm (:obj:`bool`, `optional`, defaults to :obj:`False`):
                If :obj:`True`, load `A2T-MLM` attack. Otherwise, load regular `A2T` attack.

        Returns:
            :class:`~textattack.Attack`: A2T attack.
        """
        # Constraints shared by both the A2T and A2T-MLM variants.
        constraints = [RepeatModification(), StopwordModification()]
        # For inputs made of exactly the "premise"/"hypothesis" columns, the
        # "premise" column is listed as the one to leave untouched.
        input_column_modification = InputColumnModification(
            ["premise", "hypothesis"], {"premise"}
        )
        constraints.append(input_column_modification)
        constraints.append(PartOfSpeech(allow_verb_noun_swap=False))
        constraints.append(MaxModificationRate(max_rate=0.1, min_threshold=4))
        sent_encoder = BERT(
            model_name="stsb-distilbert-base", threshold=0.9, metric="cosine"
        )
        constraints.append(sent_encoder)

        if mlm:
            # A2T-MLM: candidates come from a masked language model.
            transformation = WordSwapMaskedLM(
                method="bae", max_candidates=20, min_confidence=0.0, batch_size=16
            )
        else:
            # Regular A2T: candidates come from word embeddings, and an
            # embedding-distance constraint is added only for this variant.
            transformation = WordSwapEmbedding(max_candidates=20)
            constraints.append(WordEmbeddingDistance(min_cos_sim=0.8))

        #
        # Goal is untargeted classification
        #
        goal_function = UntargetedClassification(model_wrapper, model_batch_size=32)
        #
        # Greedily swap words with "Word Importance Ranking".
        #
        search_method = GreedyWordSwapWIR(wir_method="gradient")

        return Attack(goal_function, constraints, transformation, search_method)