anonymous8/RPD-Demo
committed on
Commit · 4943752
0 Parent(s):
initial commit
This view is limited to 50 files because it contains too many changes.
- .gitattributes +31 -0
- .gitignore +143 -0
- README.md +13 -0
- anonymous_demo/__init__.py +5 -0
- anonymous_demo/core/__init__.py +0 -0
- anonymous_demo/core/tad/__init__.py +0 -0
- anonymous_demo/core/tad/classic/__bert__/README.MD +3 -0
- anonymous_demo/core/tad/classic/__bert__/__init__.py +1 -0
- anonymous_demo/core/tad/classic/__bert__/dataset_utils/__init__.py +0 -0
- anonymous_demo/core/tad/classic/__bert__/dataset_utils/data_utils_for_inference.py +116 -0
- anonymous_demo/core/tad/classic/__bert__/models/__init__.py +1 -0
- anonymous_demo/core/tad/classic/__bert__/models/tad_bert.py +43 -0
- anonymous_demo/core/tad/classic/__init__.py +0 -0
- anonymous_demo/core/tad/models/__init__.py +9 -0
- anonymous_demo/core/tad/prediction/__init__.py +0 -0
- anonymous_demo/core/tad/prediction/tad_classifier.py +390 -0
- anonymous_demo/functional/__init__.py +3 -0
- anonymous_demo/functional/checkpoint/__init__.py +1 -0
- anonymous_demo/functional/checkpoint/checkpoint_manager.py +20 -0
- anonymous_demo/functional/config/__init__.py +1 -0
- anonymous_demo/functional/config/config_manager.py +66 -0
- anonymous_demo/functional/config/tad_config_manager.py +221 -0
- anonymous_demo/functional/dataset/__init__.py +1 -0
- anonymous_demo/functional/dataset/dataset_manager.py +21 -0
- anonymous_demo/network/__init__.py +0 -0
- anonymous_demo/network/lcf_pooler.py +26 -0
- anonymous_demo/network/lsa.py +52 -0
- anonymous_demo/network/sa_encoder.py +159 -0
- anonymous_demo/utils/__init__.py +0 -0
- anonymous_demo/utils/demo_utils.py +209 -0
- anonymous_demo/utils/logger.py +38 -0
- app.py +271 -0
- checkpoints.zip +3 -0
- requirements.txt +19 -0
- text_defense/201.SST2/stsa.binary.dev.dat +0 -0
- text_defense/201.SST2/stsa.binary.test.dat +0 -0
- text_defense/201.SST2/stsa.binary.train.dat +0 -0
- text_defense/204.AGNews10K/AGNews10K.test.dat +0 -0
- text_defense/204.AGNews10K/AGNews10K.train.dat +0 -0
- text_defense/204.AGNews10K/AGNews10K.valid.dat +0 -0
- text_defense/206.Amazon_Review_Polarity10K/amazon.test.dat +0 -0
- text_defense/206.Amazon_Review_Polarity10K/amazon.train.dat +0 -0
- textattack/__init__.py +39 -0
- textattack/__main__.py +6 -0
- textattack/attack.py +492 -0
- textattack/attack_args.py +763 -0
- textattack/attack_recipes/__init__.py +43 -0
- textattack/attack_recipes/a2t_yoo_2021.py +74 -0
- textattack/attack_recipes/attack_recipe.py +30 -0
- textattack/attack_recipes/bae_garg_2019.py +123 -0
.gitattributes
ADDED
@@ -0,0 +1,31 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,143 @@
# dev files
*.cache
*.dev.py
state_dict/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*.pyc
tests/
*.result.json
.idea/

# Embedding
glove.840B.300d.txt
glove.42B.300d.txt
glove.twitter.27B.txt

# project main files
release_note.json

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer training_logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
.DS_Store
examples/.DS_Store
README.md
ADDED
@@ -0,0 +1,13 @@
---
title: RPD-Demo
emoji: 🛡️
colorFrom: gray
colorTo: green
sdk: gradio
sdk_version: 3.0.19
app_file: app.py
pinned: false
license: mit
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
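Note: the Space metadata above declares sdk: gradio with app_file: app.py, so the Hub launches app.py under Gradio 3.0.19. The actual app.py (+271 lines in this commit) is not shown in this truncated view; the following is only a hedged sketch of the entry-point shape such a config expects, with the classify function and its output purely illustrative assumptions.

# Hypothetical minimal app.py for a Gradio 3.x Space (not the committed one)
import gradio as gr

def classify(text):
    # stand-in for the real TAD classifier; returns a fake label/confidence pair
    return {'label': 'positive', 'confidence': 0.99}

demo = gr.Interface(fn=classify, inputs=gr.Textbox(lines=3), outputs=gr.JSON())
demo.launch()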
anonymous_demo/__init__.py
ADDED
@@ -0,0 +1,5 @@
__version__ = '1.0.0'

__name__ = 'anonymous_demo'

from anonymous_demo.functional import TADCheckpointManager
anonymous_demo/core/__init__.py
ADDED
File without changes
anonymous_demo/core/tad/__init__.py
ADDED
File without changes
anonymous_demo/core/tad/classic/__bert__/README.MD
ADDED
@@ -0,0 +1,3 @@
## This is a simple migration from ABSA-PyTorch under the MIT license

Project Address: https://github.com/songyouwei/ABSA-PyTorch
anonymous_demo/core/tad/classic/__bert__/__init__.py
ADDED
@@ -0,0 +1 @@
from .models import *
anonymous_demo/core/tad/classic/__bert__/dataset_utils/__init__.py
ADDED
File without changes
anonymous_demo/core/tad/classic/__bert__/dataset_utils/data_utils_for_inference.py
ADDED
@@ -0,0 +1,116 @@
import tqdm
from findfile import find_cwd_dir
from torch.utils.data import Dataset
from transformers import AutoTokenizer


class Tokenizer4Pretraining:
    def __init__(self, max_seq_len, opt, **kwargs):
        if kwargs.pop('offline', False):
            self.tokenizer = AutoTokenizer.from_pretrained(find_cwd_dir(opt.pretrained_bert.split('/')[-1]),
                                                           do_lower_case='uncased' in opt.pretrained_bert)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(opt.pretrained_bert,
                                                           do_lower_case='uncased' in opt.pretrained_bert)
        self.max_seq_len = max_seq_len

    def text_to_sequence(self, text, reverse=False, padding='post', truncating='post'):
        return self.tokenizer.encode(text, truncation=True, padding='max_length', max_length=self.max_seq_len,
                                     return_tensors='pt')


class BERTTADDataset(Dataset):

    def __init__(self, tokenizer, opt):
        self.bert_baseline_input_colses = {
            'bert': ['text_bert_indices']
        }

        self.tokenizer = tokenizer
        self.opt = opt
        self.all_data = []

    def parse_sample(self, text):
        return [text]

    def prepare_infer_sample(self, text: str, ignore_error):
        self.process_data(self.parse_sample(text), ignore_error=ignore_error)

    def process_data(self, samples, ignore_error=True):
        all_data = []
        if len(samples) > 100:
            it = tqdm.tqdm(samples, postfix='preparing text classification inference dataloader...')
        else:
            it = samples
        for text in it:
            try:
                # handle for empty lines in inference datasets
                if text is None or '' == text.strip():
                    raise RuntimeError('Invalid Input!')

                if '!ref!' in text:
                    text, _, labels = text.strip().partition('!ref!')
                    text = text.strip()
                    if labels.count(',') == 2:
                        label, is_adv, adv_train_label = labels.strip().split(',')
                        label, is_adv, adv_train_label = label.strip(), is_adv.strip(), adv_train_label.strip()
                    elif labels.count(',') == 1:
                        label, is_adv = labels.strip().split(',')
                        label, is_adv = label.strip(), is_adv.strip()
                        adv_train_label = '-100'
                    elif labels.count(',') == 0:
                        label = labels.strip()
                        adv_train_label = '-100'
                        is_adv = '-100'
                    else:
                        label = '-100'
                        adv_train_label = '-100'
                        is_adv = '-100'

                    label = int(label)
                    adv_train_label = int(adv_train_label)
                    is_adv = int(is_adv)

                else:
                    text = text.strip()
                    label = -100
                    adv_train_label = -100
                    is_adv = -100

                text_indices = self.tokenizer.text_to_sequence('{}'.format(text))

                data = {
                    'text_bert_indices': text_indices[0],

                    'text_raw': text,

                    'label': label,

                    'adv_train_label': adv_train_label,

                    'is_adv': is_adv,

                    # 'label': self.opt.label_to_index.get(label, -100) if isinstance(label, str) else label,
                    #
                    # 'adv_train_label': self.opt.adv_train_label_to_index.get(adv_train_label, -100) if isinstance(adv_train_label, str) else adv_train_label,
                    #
                    # 'is_adv': self.opt.is_adv_to_index.get(is_adv, -100) if isinstance(is_adv, str) else is_adv,
                }

                all_data.append(data)

            except Exception as e:
                if ignore_error:
                    print('Ignore error while processing:', text)
                else:
                    raise e

        self.all_data = all_data
        return self.all_data

    def __getitem__(self, index):
        return self.all_data[index]

    def __len__(self):
        return len(self.all_data)
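The '!ref!' convention above packs up to three reference labels after the text: "<text> !ref! <label>,<is_adv>,<adv_train_label>", with missing fields defaulting to -100 (ignored). A minimal sketch (not part of the commit) of feeding one labeled sample through the dataset; the opt object here is a bare stand-in for the project's real config, and 'bert-base-uncased' is an assumed checkpoint:

from types import SimpleNamespace
from anonymous_demo.core.tad.classic.__bert__.dataset_utils.data_utils_for_inference import \
    Tokenizer4Pretraining, BERTTADDataset

opt = SimpleNamespace(pretrained_bert='bert-base-uncased')  # assumed minimal config
tokenizer = Tokenizer4Pretraining(max_seq_len=80, opt=opt)
dataset = BERTTADDataset(tokenizer=tokenizer, opt=opt)
# one comma -> label and is_adv are parsed; adv_train_label defaults to -100
dataset.prepare_infer_sample('the movie was great !ref! 1,0', ignore_error=False)
print(dataset[0]['label'], dataset[0]['is_adv'], dataset[0]['adv_train_label'])  # 1 0 -100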
anonymous_demo/core/tad/classic/__bert__/models/__init__.py
ADDED
@@ -0,0 +1 @@
from .tad_bert import TADBERT
anonymous_demo/core/tad/classic/__bert__/models/tad_bert.py
ADDED
@@ -0,0 +1,43 @@
import torch
import torch.nn as nn
from transformers.models.bert.modeling_bert import BertPooler

from anonymous_demo.network.sa_encoder import Encoder


class TADBERT(nn.Module):
    inputs = ['text_bert_indices']

    def __init__(self, bert, opt):
        super(TADBERT, self).__init__()
        self.opt = opt
        self.bert = bert
        self.pooler = BertPooler(bert.config)
        self.dense1 = nn.Linear(self.opt.hidden_dim, self.opt.class_dim)
        self.dense2 = nn.Linear(self.opt.hidden_dim, self.opt.adv_det_dim)
        self.dense3 = nn.Linear(self.opt.hidden_dim, self.opt.class_dim)

        self.encoder1 = Encoder(self.bert.config, opt=opt)
        self.encoder2 = Encoder(self.bert.config, opt=opt)
        self.encoder3 = Encoder(self.bert.config, opt=opt)

    def forward(self, inputs):
        text_raw_indices = inputs[0]
        last_hidden_state = self.bert(text_raw_indices)['last_hidden_state']

        sent_logits = self.dense1(self.pooler(last_hidden_state))
        advdet_logits = self.dense2(self.pooler(last_hidden_state))
        adv_tr_logits = self.dense3(self.pooler(last_hidden_state))

        att_score = torch.nn.functional.normalize(
            last_hidden_state.abs().sum(dim=1, keepdim=False) - last_hidden_state.abs().min(dim=1, keepdim=True)[0],
            p=1, dim=1)

        outputs = {
            'sent_logits': sent_logits,
            'advdet_logits': advdet_logits,
            'adv_tr_logits': adv_tr_logits,
            'last_hidden_state': last_hidden_state,
            'att_score': att_score
        }
        return outputs
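TADBERT above is a three-head model: one shared BertPooler output feeds dense1 (standard classification), dense2 (adversarial-example detection) and dense3 (the adversarial-training label head). A self-contained sketch of that pattern, with assumed sizes (768 hidden, two classes per head) standing in for opt.hidden_dim, opt.class_dim and opt.adv_det_dim:

import torch
import torch.nn as nn

hidden = torch.randn(4, 80, 768)                        # stand-in for BERT's last_hidden_state
pooler = nn.Sequential(nn.Linear(768, 768), nn.Tanh())  # BertPooler: dense + tanh on [CLS]
pooled = pooler(hidden[:, 0])                           # position 0 is the [CLS] token

sent_head = nn.Linear(768, 2)    # -> sent_logits
advdet_head = nn.Linear(768, 2)  # -> advdet_logits
adv_tr_head = nn.Linear(768, 2)  # -> adv_tr_logits
print(sent_head(pooled).shape, advdet_head(pooled).shape, adv_tr_head(pooled).shape)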
anonymous_demo/core/tad/classic/__init__.py
ADDED
File without changes
anonymous_demo/core/tad/models/__init__.py
ADDED
@@ -0,0 +1,9 @@
import anonymous_demo.core.tad.classic.__bert__.models


class BERTTADModelList(list):
    TADBERT = anonymous_demo.core.tad.classic.__bert__.TADBERT

    def __init__(self):
        model_list = [self.TADBERT]
        super().__init__(model_list)
anonymous_demo/core/tad/prediction/__init__.py
ADDED
File without changes
anonymous_demo/core/tad/prediction/tad_classifier.py
ADDED
@@ -0,0 +1,390 @@
import json
import os
import pickle
import time

import torch
import tqdm
from findfile import find_file, find_cwd_dir
from termcolor import colored

from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel, AutoConfig, DebertaV2ForMaskedLM, RobertaForMaskedLM, BertForMaskedLM

from ....functional.dataset.dataset_manager import detect_infer_dataset

from ..models import BERTTADModelList
from ..classic.__bert__.dataset_utils.data_utils_for_inference import BERTTADDataset, Tokenizer4Pretraining

from ....utils.demo_utils import print_args, TransformerConnectionError, get_device, build_embedding_matrix


def init_attacker(tad_classifier, defense):
    try:
        from textattack import Attacker
        from textattack.attack_recipes import BAEGarg2019, PWWSRen2019, TextFoolerJin2019, PSOZang2020, IGAWang2019, \
            GeneticAlgorithmAlzantot2018, DeepWordBugGao2018
        from textattack.datasets import Dataset
        from textattack.models.wrappers import HuggingFaceModelWrapper

        class DemoModelWrapper(HuggingFaceModelWrapper):
            def __init__(self, model):
                self.model = model

            def __call__(self, text_inputs, **kwargs):
                outputs = []
                for text_input in text_inputs:
                    raw_outputs = self.model.infer(text_input, print_result=False, **kwargs)
                    outputs.append(raw_outputs['probs'])
                return outputs

        class SentAttacker:

            def __init__(self, model, recipe_class=BAEGarg2019):
                model = model
                model_wrapper = DemoModelWrapper(model)

                recipe = recipe_class.build(model_wrapper)

                _dataset = [('', 0)]
                _dataset = Dataset(_dataset)

                self.attacker = Attacker(recipe, _dataset)

        attackers = {
            'bae': BAEGarg2019,
            'pwws': PWWSRen2019,
            'textfooler': TextFoolerJin2019,
            'pso': PSOZang2020,
            'iga': IGAWang2019,
            'ga': GeneticAlgorithmAlzantot2018,
            'wordbugger': DeepWordBugGao2018,
        }
        return SentAttacker(tad_classifier, attackers[defense])
    except Exception as e:
        print('Original error:', e)


def get_mlm_and_tokenizer(text_classifier, config):
    if isinstance(text_classifier, TADTextClassifier):
        base_model = text_classifier.model.bert.base_model
    else:
        base_model = text_classifier.bert.base_model
    pretrained_config = AutoConfig.from_pretrained(config.pretrained_bert)
    if 'deberta-v3' in config.pretrained_bert:
        MLM = DebertaV2ForMaskedLM(pretrained_config)
        MLM.deberta = base_model
    elif 'roberta' in config.pretrained_bert:
        MLM = RobertaForMaskedLM(pretrained_config)
        MLM.roberta = base_model
    else:
        MLM = BertForMaskedLM(pretrained_config)
        MLM.bert = base_model
    return MLM, AutoTokenizer.from_pretrained(config.pretrained_bert)


class TADTextClassifier:
    def __init__(self, model_arg=None, cal_perplexity=False, **kwargs):
        '''
        from_train_model: load inference model from trained model
        '''
        self.cal_perplexity = cal_perplexity
        # load from a training
        if not isinstance(model_arg, str):
            print('Load text classifier from training')
            self.model = model_arg[0]
            self.opt = model_arg[1]
            self.tokenizer = model_arg[2]
        else:
            try:
                if 'fine-tuned' in model_arg:
                    raise ValueError(
                        'Do not support to directly load a fine-tuned model, please load a .state_dict or .model instead!')
                print('Load text classifier from', model_arg)
                state_dict_path = find_file(model_arg, key='.state_dict', exclude_key=['__MACOSX'])
                model_path = find_file(model_arg, key='.model', exclude_key=['__MACOSX'])
                tokenizer_path = find_file(model_arg, key='.tokenizer', exclude_key=['__MACOSX'])
                config_path = find_file(model_arg, key='.config', exclude_key=['__MACOSX'])

                print('config: {}'.format(config_path))
                print('state_dict: {}'.format(state_dict_path))
                print('model: {}'.format(model_path))
                print('tokenizer: {}'.format(tokenizer_path))

                with open(config_path, mode='rb') as f:
                    self.opt = pickle.load(f)
                self.opt.device = get_device(kwargs.pop('auto_device', True))[0]

                if state_dict_path or model_path:
                    if hasattr(BERTTADModelList, self.opt.model.__name__):
                        if state_dict_path:
                            if kwargs.pop('offline', False):
                                self.bert = AutoModel.from_pretrained(
                                    find_cwd_dir(self.opt.pretrained_bert.split('/')[-1]))
                            else:
                                self.bert = AutoModel.from_pretrained(self.opt.pretrained_bert)
                            self.model = self.opt.model(self.bert, self.opt)
                            self.model.load_state_dict(torch.load(state_dict_path, map_location='cpu'))
                        elif model_path:
                            self.model = torch.load(model_path, map_location='cpu')

                try:
                    self.tokenizer = Tokenizer4Pretraining(max_seq_len=self.opt.max_seq_len, opt=self.opt,
                                                           **kwargs)
                except ValueError:
                    if tokenizer_path:
                        with open(tokenizer_path, mode='rb') as f:
                            self.tokenizer = pickle.load(f)
                    else:
                        raise TransformerConnectionError()

            except Exception as e:
                raise RuntimeError('Exception: {} Fail to load the model from {}! '.format(e, model_arg))

            self.infer_dataloader = None
            self.opt.eval_batch_size = kwargs.pop('eval_batch_size', 128)

        self.opt.initializer = self.opt.initializer

        if self.cal_perplexity:
            try:
                self.MLM, self.MLM_tokenizer = get_mlm_and_tokenizer(self, self.opt)
            except Exception as e:
                self.MLM, self.MLM_tokenizer = None, None

        self.to(self.opt.device)

    def to(self, device=None):
        self.opt.device = device
        self.model.to(device)
        if hasattr(self, 'MLM'):
            self.MLM.to(self.opt.device)

    def cpu(self):
        self.opt.device = 'cpu'
        self.model.to('cpu')
        if hasattr(self, 'MLM'):
            self.MLM.to('cpu')

    def cuda(self, device='cuda:0'):
        self.opt.device = device
        self.model.to(device)
        if hasattr(self, 'MLM'):
            self.MLM.to(device)

    def _log_write_args(self):
        n_trainable_params, n_nontrainable_params = 0, 0
        for p in self.model.parameters():
            n_params = torch.prod(torch.tensor(p.shape))
            if p.requires_grad:
                n_trainable_params += n_params
            else:
                n_nontrainable_params += n_params
        print(
            'n_trainable_params: {0}, n_nontrainable_params: {1}'.format(n_trainable_params, n_nontrainable_params))
        for arg in vars(self.opt):
            if getattr(self.opt, arg) is not None:
                print('>>> {0}: {1}'.format(arg, getattr(self.opt, arg)))

    def batch_infer(self,
                    target_file=None,
                    print_result=True,
                    save_result=False,
                    ignore_error=True,
                    defense: str = None
                    ):

        save_path = os.path.join(os.getcwd(), 'tad_text_classification.result.json')

        target_file = detect_infer_dataset(target_file, task='text_defense')
        if not target_file:
            raise FileNotFoundError('Can not find inference datasets!')

        if hasattr(BERTTADModelList, self.opt.model.__name__):
            dataset = BERTTADDataset(tokenizer=self.tokenizer, opt=self.opt)

        dataset.prepare_infer_dataset(target_file, ignore_error=ignore_error)
        self.infer_dataloader = DataLoader(dataset=dataset, batch_size=self.opt.eval_batch_size, pin_memory=True,
                                           shuffle=False)
        return self._infer(save_path=save_path if save_result else None, print_result=print_result, defense=defense)

    def infer(self,
              text: str = None,
              print_result=True,
              ignore_error=True,
              defense: str = None
              ):

        if hasattr(BERTTADModelList, self.opt.model.__name__):
            dataset = BERTTADDataset(tokenizer=self.tokenizer, opt=self.opt)

        if text:
            dataset.prepare_infer_sample(text, ignore_error=ignore_error)
        else:
            raise RuntimeError('Please specify your datasets path!')
        self.infer_dataloader = DataLoader(dataset=dataset, batch_size=self.opt.eval_batch_size, shuffle=False)
        return self._infer(print_result=print_result, defense=defense)[0]

    def _infer(self, save_path=None, print_result=True, defense=None):

        _params = filter(lambda p: p.requires_grad, self.model.parameters())

        correct = {True: 'Correct', False: 'Wrong'}
        results = []

        with torch.no_grad():
            self.model.eval()
            n_correct = 0
            n_labeled = 0

            n_advdet_correct = 0
            n_advdet_labeled = 0
            if len(self.infer_dataloader.dataset) >= 100:
                it = tqdm.tqdm(self.infer_dataloader, postfix='inferring...')
            else:
                it = self.infer_dataloader
            for _, sample in enumerate(it):
                inputs = [sample[col].to(self.opt.device) for col in self.opt.inputs_cols]
                outputs = self.model(inputs)
                logits, advdet_logits, adv_tr_logits = outputs['sent_logits'], outputs['advdet_logits'], outputs['adv_tr_logits']
                probs, advdet_probs, adv_tr_probs = torch.softmax(logits, dim=-1), torch.softmax(advdet_logits, dim=-1), torch.softmax(adv_tr_logits, dim=-1)

                for i, (prob, advdet_prob, adv_tr_prob) in enumerate(zip(probs, advdet_probs, adv_tr_probs)):
                    text_raw = sample['text_raw'][i]

                    pred_label = int(prob.argmax(axis=-1))
                    pred_is_adv_label = int(advdet_prob.argmax(axis=-1))
                    pred_adv_tr_label = int(adv_tr_prob.argmax(axis=-1))
                    ref_label = int(sample['label'][i]) if int(sample['label'][i]) in self.opt.index_to_label else ''
                    ref_is_adv_label = int(sample['is_adv'][i]) if int(
                        sample['is_adv'][i]) in self.opt.index_to_is_adv else ''
                    ref_adv_tr_label = int(sample['adv_train_label'][i]) if int(
                        sample['adv_train_label'][i]) in self.opt.index_to_adv_train_label else ''

                    if self.cal_perplexity:
                        ids = self.MLM_tokenizer(text_raw, return_tensors="pt")
                        ids['labels'] = ids['input_ids'].clone()
                        ids = ids.to(self.opt.device)
                        loss = self.MLM(**ids)['loss']
                        perplexity = float(torch.exp(loss / ids['input_ids'].size(1)))
                    else:
                        perplexity = 'N.A.'

                    result = {
                        'text': text_raw,

                        'label': self.opt.index_to_label[pred_label],
                        'probs': prob.cpu().numpy(),
                        'confidence': float(max(prob)),
                        'ref_label': self.opt.index_to_label[ref_label] if isinstance(ref_label, int) else ref_label,
                        'ref_label_check': correct[pred_label == ref_label] if ref_label != -100 else '',
                        'is_fixed': False,

                        'is_adv_label': self.opt.index_to_is_adv[pred_is_adv_label],
                        'is_adv_probs': advdet_prob.cpu().numpy(),
                        'is_adv_confidence': float(max(advdet_prob)),
                        'ref_is_adv_label': self.opt.index_to_is_adv[ref_is_adv_label] if isinstance(ref_is_adv_label, int) else ref_is_adv_label,
                        'ref_is_adv_check': correct[pred_is_adv_label == ref_is_adv_label] if ref_is_adv_label != -100 and isinstance(ref_is_adv_label, int) else '',

                        'pred_adv_tr_label': self.opt.index_to_label[pred_adv_tr_label],
                        'ref_adv_tr_label': self.opt.index_to_label[ref_adv_tr_label],

                        'perplexity': perplexity,
                    }
                    if defense:
                        try:
                            if not hasattr(self, 'sent_attacker'):
                                self.sent_attacker = init_attacker(self, defense.lower())
                            if result['is_adv_label'] == '1':
                                res = self.sent_attacker.attacker.simple_attack(text_raw, int(result['label']))
                                new_infer_res = self.infer(res.perturbed_result.attacked_text.text, print_result=False)
                                result['perturbed_label'] = result['label']
                                result['label'] = new_infer_res['label']
                                result['probs'] = new_infer_res['probs']
                                result['ref_label_check'] = correct[int(result['label']) == ref_label] if ref_label != -100 else ''
                                result['restored_text'] = res.perturbed_result.attacked_text.text
                                result['is_fixed'] = True
                            else:
                                result['restored_text'] = ''
                                result['is_fixed'] = False

                        except Exception as e:
                            print('Error:{}, try install TextAttack and tensorflow_text after 10 seconds...'.format(e))
                            time.sleep(10)
                            raise RuntimeError('Installation done, please run again...')

                    if ref_label != -100:
                        n_labeled += 1

                        if result['label'] == result['ref_label']:
                            n_correct += 1

                    if ref_is_adv_label != -100:
                        n_advdet_labeled += 1
                        if ref_is_adv_label == pred_is_adv_label:
                            n_advdet_correct += 1

                    results.append(result)

        try:
            if print_result:
                for ex_id, result in enumerate(results):
                    text_printing = result['text'][:]
                    text_info = ''
                    if result['label'] != '-100':
                        if not result['ref_label']:
                            text_info += ' -> <CLS:{}(ref:{} confidence:{})>'.format(result['label'],
                                                                                     result['ref_label'],
                                                                                     result['confidence'])
                        elif result['label'] == result['ref_label']:
                            text_info += colored(
                                ' -> <CLS:{}(ref:{} confidence:{})>'.format(result['label'], result['ref_label'],
                                                                            result['confidence']), 'green')
                        else:
                            text_info += colored(
                                ' -> <CLS:{}(ref:{} confidence:{})>'.format(result['label'], result['ref_label'],
                                                                            result['confidence']), 'red')

                    # AdvDet
                    if result['is_adv_label'] != '-100':
                        if not result['ref_is_adv_label']:
                            text_info += ' -> <AdvDet:{}(ref:{} confidence:{})>'.format(result['is_adv_label'],
                                                                                        result['ref_is_adv_check'],
                                                                                        result['is_adv_confidence'])
                        elif result['is_adv_label'] == result['ref_is_adv_label']:
                            text_info += colored(
                                ' -> <AdvDet:{}(ref:{} confidence:{})>'.format(result['is_adv_label'],
                                                                               result['ref_is_adv_label'],
                                                                               result['is_adv_confidence']), 'green')
                        else:
                            text_info += colored(
                                ' -> <AdvDet:{}(ref:{} confidence:{})>'.format(result['is_adv_label'],
                                                                               result['ref_is_adv_label'],
                                                                               result['is_adv_confidence']), 'red')
                    text_printing += text_info
                    if self.cal_perplexity:
                        text_printing += colored(' --> <perplexity:{}>'.format(result['perplexity']), 'yellow')
                    print('Example {}: {}'.format(ex_id, text_printing))
            if save_path:
                with open(save_path, 'w', encoding='utf8') as fout:
                    json.dump(str(results), fout, ensure_ascii=False)
                    print('inference result saved in: {}'.format(save_path))
        except Exception as e:
            print('Can not save result: {}, Exception: {}'.format(text_raw, e))

        if len(results) > 1:
            print('CLS Acc:{}%'.format(100 * n_correct / n_labeled if n_labeled else ''))
            print('AdvDet Acc:{}%'.format(100 * n_advdet_correct / n_advdet_labeled if n_advdet_labeled else ''))

        return results

    def clear_input_samples(self):
        self.dataset.all_data = []
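A hedged usage sketch of the classifier above: infer() tokenizes one string, runs the three heads, and, when defense is set, repairs inputs flagged adversarial by re-attacking them with the matching TextAttack recipe from the attackers dict in init_attacker(). The checkpoint path below is a placeholder assumption, not a path this commit guarantees:

from anonymous_demo.core.tad.prediction.tad_classifier import TADTextClassifier

clf = TADTextClassifier('checkpoints/TAD-SST2', eval_batch_size=32)  # hypothetical path
res = clf.infer('the film is a masterpiece', print_result=True, defense='pwws')
print(res['label'], res['is_adv_label'], res['confidence'], res['is_fixed'])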
anonymous_demo/functional/__init__.py
ADDED
@@ -0,0 +1,3 @@
from anonymous_demo.functional.checkpoint.checkpoint_manager import TADCheckpointManager

from anonymous_demo.functional.config import TADConfigManager
anonymous_demo/functional/checkpoint/__init__.py
ADDED
@@ -0,0 +1 @@
from .checkpoint_manager import TADCheckpointManager
anonymous_demo/functional/checkpoint/checkpoint_manager.py
ADDED
@@ -0,0 +1,20 @@
import os
from findfile import find_file

from anonymous_demo.core.tad.prediction.tad_classifier import TADTextClassifier
from anonymous_demo.utils.demo_utils import retry


class CheckpointManager:
    pass


class TADCheckpointManager(CheckpointManager):
    @staticmethod
    @retry
    def get_tad_text_classifier(checkpoint: str = None,
                                eval_batch_size=128,
                                **kwargs):

        tad_text_classifier = TADTextClassifier(checkpoint, eval_batch_size=eval_batch_size, **kwargs)
        return tad_text_classifier
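This is the package's public entry point (re-exported from anonymous_demo/__init__.py); the @retry decorator wraps construction so transient failures, such as model downloads, are retried. A sketch, with the checkpoint name again an assumed placeholder:

from anonymous_demo import TADCheckpointManager

tad_classifier = TADCheckpointManager.get_tad_text_classifier(
    checkpoint='TAD-SST2',  # hypothetical checkpoint key/path
    auto_device=True)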
anonymous_demo/functional/config/__init__.py
ADDED
@@ -0,0 +1 @@
from .tad_config_manager import TADConfigManager
anonymous_demo/functional/config/config_manager.py
ADDED
@@ -0,0 +1,66 @@
from argparse import Namespace

import torch

one_shot_messages = set()


def config_check(args):
    pass


class ConfigManager(Namespace):

    def __init__(self, args=None, **kwargs):
        """
        The ConfigManager is a subclass of argparse.Namespace; it is backed by a parameter dict and counts the call frequency of each parameter
        :param args: A parameter dict
        :param kwargs: Same params as Namespace
        """
        if not args:
            args = {}
        super().__init__(**kwargs)

        if isinstance(args, Namespace):
            self.args = vars(args)
            self.args_call_count = {arg: 0 for arg in vars(args)}
        else:
            self.args = args
            self.args_call_count = {arg: 0 for arg in args}

    def __getattribute__(self, arg_name):
        if arg_name == 'args' or arg_name == 'args_call_count':
            return super().__getattribute__(arg_name)
        try:
            value = super().__getattribute__('args')[arg_name]
            args_call_count = super().__getattribute__('args_call_count')
            args_call_count[arg_name] += 1
            super().__setattr__('args_call_count', args_call_count)
            return value

        except Exception as e:

            return super().__getattribute__(arg_name)

    def __setattr__(self, arg_name, value):
        if arg_name == 'args' or arg_name == 'args_call_count':
            super().__setattr__(arg_name, value)
            return
        try:
            args = super().__getattribute__('args')
            args[arg_name] = value
            super().__setattr__('args', args)
            args_call_count = super().__getattribute__('args_call_count')

            if arg_name in args_call_count:
                # args_call_count[arg_name] += 1
                super().__setattr__('args_call_count', args_call_count)

            else:
                args_call_count[arg_name] = 0
                super().__setattr__('args_call_count', args_call_count)

        except Exception as e:
            super().__setattr__(arg_name, value)

        config_check(args)
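ConfigManager stores everything in an args dict and bumps args_call_count on every attribute read, so unused parameters can be spotted afterwards. A small sketch (not from the commit) of that bookkeeping:

cfg = ConfigManager({'learning_rate': 2e-5})
_ = cfg.learning_rate
_ = cfg.learning_rate
print(cfg.args_call_count['learning_rate'])  # 2 -- each attribute read is counted
cfg.batch_size = 16                          # keys added later start at count 0
print(cfg.args_call_count['batch_size'])     # 0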
anonymous_demo/functional/config/tad_config_manager.py
ADDED
@@ -0,0 +1,221 @@
import copy

from anonymous_demo.functional.config.config_manager import ConfigManager
from anonymous_demo.core.tad.classic.__bert__.models import TADBERT

_tad_config_template = {'model': TADBERT,
                        'optimizer': "adamw",
                        'learning_rate': 0.00002,
                        'patience': 99999,
                        'pretrained_bert': "microsoft/mdeberta-v3-base",
                        'cache_dataset': True,
                        'warmup_step': -1,
                        'show_metric': False,
                        'max_seq_len': 80,
                        'dropout': 0,
                        'l2reg': 0.000001,
                        'num_epoch': 10,
                        'batch_size': 16,
                        'initializer': 'xavier_uniform_',
                        'seed': 52,
                        'polarities_dim': 3,
                        'log_step': 10,
                        'evaluate_begin': 0,
                        'cross_validate_fold': -1,
                        'use_amp': False,
                        # split train and test datasets into 5 folds and repeat 3 training
                        }

_tad_config_base = {'model': TADBERT,
                    'optimizer': "adamw",
                    'learning_rate': 0.00002,
                    'pretrained_bert': "microsoft/deberta-v3-base",
                    'cache_dataset': True,
                    'warmup_step': -1,
                    'show_metric': False,
                    'max_seq_len': 80,
                    'patience': 99999,
                    'dropout': 0,
                    'l2reg': 0.000001,
                    'num_epoch': 10,
                    'batch_size': 16,
                    'initializer': 'xavier_uniform_',
                    'seed': 52,
                    'polarities_dim': 3,
                    'log_step': 10,
                    'evaluate_begin': 0,
                    'cross_validate_fold': -1
                    # split train and test datasets into 5 folds and repeat 3 training
                    }

_tad_config_english = {'model': TADBERT,
                       'optimizer': "adamw",
                       'learning_rate': 0.00002,
                       'patience': 99999,
                       'pretrained_bert': "microsoft/deberta-v3-base",
                       'cache_dataset': True,
                       'warmup_step': -1,
                       'show_metric': False,
                       'max_seq_len': 80,
                       'dropout': 0,
                       'l2reg': 0.000001,
                       'num_epoch': 10,
                       'batch_size': 16,
                       'initializer': 'xavier_uniform_',
                       'seed': 52,
                       'polarities_dim': 3,
                       'log_step': 10,
                       'evaluate_begin': 0,
                       'cross_validate_fold': -1
                       # split train and test datasets into 5 folds and repeat 3 training
                       }

_tad_config_multilingual = {'model': TADBERT,
                            'optimizer': "adamw",
                            'learning_rate': 0.00002,
                            'patience': 99999,
                            'pretrained_bert': "microsoft/mdeberta-v3-base",
                            'cache_dataset': True,
                            'warmup_step': -1,
                            'show_metric': False,
                            'max_seq_len': 80,
                            'dropout': 0,
                            'l2reg': 0.000001,
                            'num_epoch': 10,
                            'batch_size': 16,
                            'initializer': 'xavier_uniform_',
                            'seed': 52,
                            'polarities_dim': 3,
                            'log_step': 10,
                            'evaluate_begin': 0,
                            'cross_validate_fold': -1
                            # split train and test datasets into 5 folds and repeat 3 training
                            }

_tad_config_chinese = {'model': TADBERT,
                       'optimizer': "adamw",
                       'learning_rate': 0.00002,
                       'patience': 99999,
                       'cache_dataset': True,
                       'warmup_step': -1,
                       'show_metric': False,
                       'pretrained_bert': "bert-base-chinese",
                       'max_seq_len': 80,
                       'dropout': 0,
                       'l2reg': 0.000001,
                       'num_epoch': 10,
                       'batch_size': 16,
                       'initializer': 'xavier_uniform_',
                       'seed': 52,
                       'polarities_dim': 3,
                       'log_step': 10,
                       'evaluate_begin': 0,
                       'cross_validate_fold': -1
                       # split train and test datasets into 5 folds and repeat 3 training
                       }


class TADConfigManager(ConfigManager):
    def __init__(self, args, **kwargs):
        """
        Available Params: {'model': BERT,
                          'optimizer': "adamw",
                          'learning_rate': 0.00002,
                          'pretrained_bert': "roberta-base",
                          'cache_dataset': True,
                          'warmup_step': -1,
                          'show_metric': False,
                          'max_seq_len': 80,
                          'patience': 99999,
                          'dropout': 0,
                          'l2reg': 0.000001,
                          'num_epoch': 10,
                          'batch_size': 16,
                          'initializer': 'xavier_uniform_',
                          'seed': {52, 25}
                          'embed_dim': 768,
                          'hidden_dim': 768,
                          'polarities_dim': 3,
                          'log_step': 10,
                          'evaluate_begin': 0,
                          'cross_validate_fold': -1  # split train and test datasets into 5 folds and repeat 3 training
                          }
        :param args:
        :param kwargs:
        """
        super().__init__(args, **kwargs)

    @staticmethod
    def set_tad_config(configType: str, newitem: dict):
        if isinstance(newitem, dict):
            if configType == 'template':
                _tad_config_template.update(newitem)
            elif configType == 'base':
                _tad_config_base.update(newitem)
            elif configType == 'english':
                _tad_config_english.update(newitem)
            elif configType == 'chinese':
                _tad_config_chinese.update(newitem)
            elif configType == 'multilingual':
                _tad_config_multilingual.update(newitem)
            elif configType == 'glove':
                _tad_config_glove.update(newitem)
            else:
                raise ValueError(
                    "Wrong value of config type supplied, please use one from following type: template, base, english, chinese, multilingual, glove")
        else:
            raise TypeError("Wrong type of new config item supplied, please use dict e.g.{'NewConfig': NewValue}")

    @staticmethod
    def set_tad_config_template(newitem):
        TADConfigManager.set_tad_config('template', newitem)

    @staticmethod
    def set_tad_config_base(newitem):
        TADConfigManager.set_tad_config('base', newitem)

    @staticmethod
    def set_tad_config_english(newitem):
        TADConfigManager.set_tad_config('english', newitem)

    @staticmethod
    def set_tad_config_chinese(newitem):
        TADConfigManager.set_tad_config('chinese', newitem)

    @staticmethod
    def set_tad_config_multilingual(newitem):
        TADConfigManager.set_tad_config('multilingual', newitem)

    @staticmethod
    def set_tad_config_glove(newitem):
        TADConfigManager.set_tad_config('glove', newitem)

    @staticmethod
    def get_tad_config_template() -> ConfigManager:
        _tad_config_template.update(_tad_config_template)
        return TADConfigManager(copy.deepcopy(_tad_config_template))

    @staticmethod
    def get_tad_config_base() -> ConfigManager:
        _tad_config_template.update(_tad_config_base)
        return TADConfigManager(copy.deepcopy(_tad_config_template))

    @staticmethod
    def get_tad_config_english() -> ConfigManager:
        _tad_config_template.update(_tad_config_english)
        return TADConfigManager(copy.deepcopy(_tad_config_template))

    @staticmethod
    def get_tad_config_chinese() -> ConfigManager:
        _tad_config_template.update(_tad_config_chinese)
        return TADConfigManager(copy.deepcopy(_tad_config_template))

    @staticmethod
    def get_tad_config_multilingual() -> ConfigManager:
        _tad_config_template.update(_tad_config_multilingual)
        return TADConfigManager(copy.deepcopy(_tad_config_template))

    @staticmethod
    def get_tad_config_glove() -> ConfigManager:
        _tad_config_template.update(_tad_config_glove)
        return TADConfigManager(copy.deepcopy(_tad_config_template))
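Each get_tad_config_* merges a language preset into _tad_config_template and hands back a deep copy wrapped in TADConfigManager. (Note that the glove variants reference a _tad_config_glove dict this file never defines, so they would raise NameError as committed.) A usage sketch:

config = TADConfigManager.get_tad_config_english()
config.max_seq_len = 128       # per-instance override via ConfigManager.__setattr__
print(config.pretrained_bert)  # microsoft/deberta-v3-base
TADConfigManager.set_tad_config_english({'batch_size': 32})  # mutates the preset itself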
anonymous_demo/functional/dataset/__init__.py
ADDED
@@ -0,0 +1 @@
from anonymous_demo.functional.dataset.dataset_manager import (detect_infer_dataset)
anonymous_demo/functional/dataset/dataset_manager.py
ADDED
@@ -0,0 +1,21 @@
import os
from findfile import find_files, find_dir

filter_key_words = ['.py', '.md', 'readme', 'log', 'result', 'zip',
                    '.state_dict', '.model', '.png', 'acc_', 'f1_', '.backup', '.bak']


def detect_infer_dataset(dataset_path, task='apc'):
    dataset_file = []
    if isinstance(dataset_path, str) and os.path.isfile(dataset_path):
        dataset_file.append(dataset_path)
        return dataset_file

    for d in dataset_path:
        if not os.path.exists(d):
            search_path = find_dir(os.getcwd(), [d, task, 'dataset'], exclude_key=filter_key_words, disable_alert=False)
            dataset_file += find_files(search_path, ['.inference', d], exclude_key=['train.'] + filter_key_words)
        else:
            dataset_file += find_files(d, ['.inference', task], exclude_key=['train.'] + filter_key_words)

    return dataset_file
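detect_infer_dataset accepts either an existing file path or a list of dataset names; missing names are resolved by searching the working directory via findfile, keeping '.inference'-style files and filtering out training artifacts. A sketch (the dataset name matches this commit's text_defense/ folders, but only files whose names match '.inference' would be returned):

from anonymous_demo.functional.dataset import detect_infer_dataset

files = detect_infer_dataset(['204.AGNews10K'], task='text_defense')
print(files)  # paths of matching '.inference' files found under the working directory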
anonymous_demo/network/__init__.py
ADDED
File without changes
anonymous_demo/network/lcf_pooler.py
ADDED
@@ -0,0 +1,26 @@
import numpy
import torch
import torch.nn as nn


class LCF_Pooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states, lcf_vec):
        device = hidden_states.device
        lcf_vec = lcf_vec.detach().cpu().numpy()

        pooled_output = numpy.zeros((hidden_states.shape[0], hidden_states.shape[2]), dtype=numpy.float32)
        hidden_states = hidden_states.detach().cpu().numpy()
        for i, vec in enumerate(lcf_vec):
            lcf_ids = [j for j in range(len(vec)) if sum(vec[j] - 1.) == 0]
            pooled_output[i] = hidden_states[i][lcf_ids[len(lcf_ids) // 2]]

        pooled_output = torch.Tensor(pooled_output).to(device)
        pooled_output = self.dense(pooled_output)
        pooled_output = self.activation(pooled_output)
        return pooled_output
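LCF_Pooler pools by selection rather than attention: lcf_vec marks local-context positions with all-ones rows, and the pooler keeps the hidden state of the middle marked token before applying dense + tanh. A self-contained sketch with a tiny assumed hidden size:

import torch

cfg = type('Cfg', (), {'hidden_size': 8})()  # minimal stand-in config
pooler = LCF_Pooler(cfg)
hidden = torch.randn(1, 5, 8)
lcf = torch.zeros(1, 5, 8)
lcf[0, 1:4] = 1.0           # positions 1..3 form the local-context window
out = pooler(hidden, lcf)   # selects hidden[0, 2], the middle marked position
print(out.shape)            # torch.Size([1, 8])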
anonymous_demo/network/lsa.py
ADDED
@@ -0,0 +1,52 @@
import torch
from torch import nn

from anonymous_demo.network.sa_encoder import Encoder


class LSA(nn.Module):
    def __init__(self, bert, opt):
        super(LSA, self).__init__()
        self.opt = opt

        self.encoder = Encoder(bert.config, opt)
        self.encoder_left = Encoder(bert.config, opt)
        self.encoder_right = Encoder(bert.config, opt)
        self.linear_window_3h = nn.Linear(opt.embed_dim * 3, opt.embed_dim)
        self.linear_window_2h = nn.Linear(opt.embed_dim * 2, opt.embed_dim)
        self.eta1 = nn.Parameter(torch.tensor(self.opt.eta, dtype=torch.float))
        self.eta2 = nn.Parameter(torch.tensor(self.opt.eta, dtype=torch.float))

    def forward(self, global_context_features, spc_mask_vec, lcf_matrix, left_lcf_matrix, right_lcf_matrix):
        masked_global_context_features = torch.mul(spc_mask_vec, global_context_features)

        # --------------------------------------------------- #
        lcf_features = torch.mul(global_context_features, lcf_matrix)
        lcf_features = self.encoder(lcf_features)
        # --------------------------------------------------- #
        left_lcf_features = torch.mul(masked_global_context_features, left_lcf_matrix)
        left_lcf_features = self.encoder_left(left_lcf_features)
        # --------------------------------------------------- #
        right_lcf_features = torch.mul(masked_global_context_features, right_lcf_matrix)
        right_lcf_features = self.encoder_right(right_lcf_features)
        # --------------------------------------------------- #
        if self.opt.window in ('lr', 'rl'):
            # reinitialise the learnable weights if they have collapsed to zero or below
            if self.eta1 <= 0 and self.opt.eta != -1:
                torch.nn.init.uniform_(self.eta1)
                print('reset eta1 to: {}'.format(self.eta1.item()))
            if self.eta2 <= 0 and self.opt.eta != -1:
                torch.nn.init.uniform_(self.eta2)
                print('reset eta2 to: {}'.format(self.eta2.item()))
            if self.opt.eta >= 0:
                cat_features = torch.cat((lcf_features, self.eta1 * left_lcf_features, self.eta2 * right_lcf_features), -1)
            else:
                cat_features = torch.cat((lcf_features, left_lcf_features, right_lcf_features), -1)
            sent_out = self.linear_window_3h(cat_features)
        elif self.opt.window == 'l':
            sent_out = self.linear_window_2h(torch.cat((lcf_features, self.eta1 * left_lcf_features), -1))
        elif self.opt.window == 'r':
            sent_out = self.linear_window_2h(torch.cat((lcf_features, self.eta2 * right_lcf_features), -1))
        else:
            raise KeyError('Invalid parameter: {}'.format(self.opt.window))

        return sent_out
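
A minimal sketch (not part of this commit) of how LSA is wired up. The opt namespace is hypothetical; the real values come from the TAD config managers. With window='lr' the left and right local-context features are weighted by the learnable eta parameters and fused through linear_window_3h.

    import torch
    from types import SimpleNamespace
    from transformers import BertConfig

    from anonymous_demo.network.lsa import LSA

    config = BertConfig()                   # hidden_size=768, 12 heads by default
    bert = SimpleNamespace(config=config)   # LSA only reads bert.config
    opt = SimpleNamespace(embed_dim=768, eta=0.5, window='lr', max_seq_len=16)  # illustrative values

    lsa = LSA(bert, opt)
    x = torch.randn(2, 16, 768)             # global context features
    spc_mask = torch.ones_like(x)           # keep every token in this toy case
    lcf = left = right = torch.ones_like(x) # trivial all-ones local-context masks
    out = lsa(x, spc_mask, lcf, left, right)
    print(out.shape)                        # torch.Size([2, 16, 768])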
anonymous_demo/network/sa_encoder.py
ADDED
@@ -0,0 +1,159 @@
import math

import numpy as np
import torch
import torch.nn as nn


class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(
            config.attention_probs_dropout_prob if hasattr(config, 'attention_probs_dropout_prob') else 0)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k, v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs


class Encoder(nn.Module):
    def __init__(self, config, opt, layer_num=1):
        super(Encoder, self).__init__()
        self.opt = opt
        self.config = config
        self.encoder = nn.ModuleList([SelfAttention(config, opt) for _ in range(layer_num)])
        self.tanh = torch.nn.Tanh()

    def forward(self, x):
        for i, enc in enumerate(self.encoder):
            x = self.tanh(enc(x)[0])
        return x


class SelfAttention(nn.Module):
    def __init__(self, config, opt):
        super(SelfAttention, self).__init__()
        self.opt = opt
        self.config = config
        self.SA = BertSelfAttention(config)

    def forward(self, inputs):
        # an all-zero additive attention mask, i.e. no position is masked out
        zero_vec = np.zeros((inputs.size(0), 1, 1, self.opt.max_seq_len))
        zero_tensor = torch.tensor(zero_vec).float().to(inputs.device)
        SA_out = self.SA(inputs, zero_tensor)
        return SA_out
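
A minimal sketch (not part of this commit): Encoder stacks layer_num BERT-style self-attention layers with an all-zero additive attention mask (nothing is masked out) and applies tanh after each layer. The opt stub is hypothetical; only max_seq_len is read, and it must match the sequence length of the input.

    import torch
    from types import SimpleNamespace
    from transformers import BertConfig

    from anonymous_demo.network.sa_encoder import Encoder

    config = BertConfig()                  # 768 hidden units, 12 heads by default
    opt = SimpleNamespace(max_seq_len=10)  # hypothetical opt stub
    enc = Encoder(config, opt, layer_num=1)

    x = torch.randn(4, 10, 768)            # (batch, max_seq_len, hidden)
    print(enc(x).shape)                    # torch.Size([4, 10, 768])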
anonymous_demo/utils/__init__.py
ADDED
File without changes
anonymous_demo/utils/demo_utils.py
ADDED
@@ -0,0 +1,209 @@
import json
import os
import pickle
import signal
import threading
import time
import zipfile

import gdown
import numpy as np
import requests
import torch
import tqdm
from autocuda import auto_cuda, auto_cuda_name
from findfile import find_files, find_cwd_file, find_file
from termcolor import colored
from functools import wraps

from update_checker import parse_version

from anonymous_demo import __version__


def save_args(config, save_path):
    with open(os.path.join(save_path), mode='w', encoding='utf8') as f:
        for arg in config.args:
            if config.args_call_count[arg]:
                f.write('{}: {}\n'.format(arg, config.args[arg]))


def print_args(config, logger=None, mode=0):
    args = [key for key in sorted(config.args.keys())]
    for arg in args:
        if logger:
            logger.info('{0}:{1}\t-->\tCalling Count:{2}'.format(arg, config.args[arg], config.args_call_count[arg]))
        else:
            print('{0}:{1}\t-->\tCalling Count:{2}'.format(arg, config.args[arg], config.args_call_count[arg]))


def check_and_fix_labels(label_set: set, label_name, all_data, opt):
    if '-100' in label_set:
        # keep '-100' mapped to the ignore index and shift the remaining labels down by one
        label_to_index = {origin_label: int(idx) - 1 if origin_label != '-100' else -100
                          for origin_label, idx in zip(sorted(label_set), range(len(label_set)))}
        index_to_label = {int(idx) - 1 if origin_label != '-100' else -100: origin_label
                          for origin_label, idx in zip(sorted(label_set), range(len(label_set)))}
    else:
        label_to_index = {origin_label: int(idx) for origin_label, idx in zip(sorted(label_set), range(len(label_set)))}
        index_to_label = {int(idx): origin_label for origin_label, idx in zip(sorted(label_set), range(len(label_set)))}
    if 'index_to_label' not in opt.args:
        opt.index_to_label = index_to_label
        opt.label_to_index = label_to_index

    if opt.index_to_label != index_to_label:
        opt.index_to_label.update(index_to_label)
        opt.label_to_index.update(label_to_index)
    num_label = {l: 0 for l in label_set}
    num_label['Sum'] = len(all_data)
    for item in all_data:
        try:
            num_label[item[label_name]] += 1
            item[label_name] = label_to_index[item[label_name]]
        except Exception as e:
            # fall back to attribute access for dataset items that are objects
            num_label[item.polarity] += 1
            item.polarity = label_to_index[item.polarity]
    print('Dataset Label Details: {}'.format(num_label))


def check_and_fix_IOB_labels(label_map, opt):
    index_to_IOB_label = {int(label_map[origin_label]): origin_label for origin_label in label_map}
    opt.index_to_IOB_label = index_to_IOB_label


def get_device(auto_device):
    if isinstance(auto_device, str) and auto_device == 'allcuda':
        device = 'cuda'
    elif isinstance(auto_device, str):
        device = auto_device
    elif isinstance(auto_device, bool):
        device = auto_cuda() if auto_device else 'cpu'
    else:
        device = auto_cuda()
    try:
        torch.device(device)
    except RuntimeError as e:
        print(colored('Device assignment error: {}, redirect to CPU'.format(e), 'red'))
        device = 'cpu'
    device_name = auto_cuda_name()
    return device, device_name


def _load_word_vec(path, word2idx=None, embed_dim=300):
    fin = open(path, 'r', encoding='utf-8', newline='\n', errors='ignore')
    word_vec = {}
    for line in tqdm.tqdm(fin.readlines(), postfix='Loading embedding file...'):
        tokens = line.rstrip().split()
        word, vec = ' '.join(tokens[:-embed_dim]), tokens[-embed_dim:]
        if word in word2idx.keys():
            word_vec[word] = np.asarray(vec, dtype='float32')
    return word_vec


def build_embedding_matrix(word2idx, embed_dim, dat_fname, opt):
    if not os.path.exists('run'):
        os.makedirs('run')
    embed_matrix_path = 'run/{}'.format(os.path.join(opt.dataset_name, dat_fname))
    if os.path.exists(embed_matrix_path):
        print(colored('Loading cached embedding_matrix from {} '
                      '(Please remove all cached files if there is any problem!)'.format(embed_matrix_path), 'green'))
        embedding_matrix = pickle.load(open(embed_matrix_path, 'rb'))
    else:
        glove_path = prepare_glove840_embedding(embed_matrix_path)
        embedding_matrix = np.zeros((len(word2idx) + 2, embed_dim))

        word_vec = _load_word_vec(glove_path, word2idx=word2idx, embed_dim=embed_dim)

        for word, i in tqdm.tqdm(word2idx.items(), postfix=colored('Building embedding_matrix {}'.format(dat_fname), 'yellow')):
            vec = word_vec.get(word)
            if vec is not None:
                embedding_matrix[i] = vec
        pickle.dump(embedding_matrix, open(embed_matrix_path, 'wb'))
    return embedding_matrix


def pad_and_truncate(sequence, maxlen, dtype='int64', padding='post', truncating='post', value=0):
    x = (np.ones(maxlen) * value).astype(dtype)
    if truncating == 'pre':
        trunc = sequence[-maxlen:]
    else:
        trunc = sequence[:maxlen]
    trunc = np.asarray(trunc, dtype=dtype)
    if padding == 'post':
        x[:len(trunc)] = trunc
    else:
        x[-len(trunc):] = trunc
    return x


class TransformerConnectionError(ValueError):
    def __init__(self):
        pass


def retry(f):
    @wraps(f)
    def decorated(*args, **kwargs):
        count = 5
        while count:
            try:
                return f(*args, **kwargs)
            except (
                TransformerConnectionError,
                requests.exceptions.RequestException,
                requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                requests.exceptions.ConnectTimeout,
                requests.exceptions.ProxyError,
                requests.exceptions.SSLError,
                requests.exceptions.BaseHTTPError,
            ) as e:
                print(colored('Training Exception: {}, will retry later'.format(e), 'red'))
                time.sleep(60)
                count -= 1

    return decorated


def save_json(dic, save_path):
    if isinstance(dic, str):
        dic = eval(dic)
    with open(save_path, 'w', encoding='utf-8') as f:
        str_ = json.dumps(dic, ensure_ascii=False)
        f.write(str_)


def load_json(save_path):
    with open(save_path, 'r', encoding='utf-8') as f:
        data = f.readline().strip()
        print(type(data), data)
        dic = json.loads(data)
    return dic


def init_optimizer(optimizer):
    optimizers = {
        'adadelta': torch.optim.Adadelta,  # default lr=1.0
        'adagrad': torch.optim.Adagrad,  # default lr=0.01
        'adam': torch.optim.Adam,  # default lr=0.001
        'adamax': torch.optim.Adamax,  # default lr=0.002
        'asgd': torch.optim.ASGD,  # default lr=0.01
        'rmsprop': torch.optim.RMSprop,  # default lr=0.01
        'sgd': torch.optim.SGD,
        'adamw': torch.optim.AdamW,
        torch.optim.Adadelta: torch.optim.Adadelta,  # default lr=1.0
        torch.optim.Adagrad: torch.optim.Adagrad,  # default lr=0.01
        torch.optim.Adam: torch.optim.Adam,  # default lr=0.001
        torch.optim.Adamax: torch.optim.Adamax,  # default lr=0.002
        torch.optim.ASGD: torch.optim.ASGD,  # default lr=0.01
        torch.optim.RMSprop: torch.optim.RMSprop,  # default lr=0.01
        torch.optim.SGD: torch.optim.SGD,
        torch.optim.AdamW: torch.optim.AdamW,
    }
    if optimizer in optimizers:
        return optimizers[optimizer]
    elif hasattr(torch.optim, optimizer.__name__):
        return optimizer
    else:
        raise KeyError('Unsupported optimizer: {}. Please use string or the optimizer objects in torch.optim as your optimizer'.format(optimizer))
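
A minimal usage sketch for pad_and_truncate (not part of this commit): sequences longer than maxlen are cut on the chosen side and shorter ones are padded with value.

    from anonymous_demo.utils.demo_utils import pad_and_truncate

    print(pad_and_truncate([5, 6, 7], maxlen=5))              # [5 6 7 0 0]
    print(pad_and_truncate([1, 2, 3, 4, 5, 6], maxlen=4))     # [1 2 3 4]
    print(pad_and_truncate([1, 2], maxlen=4, padding='pre'))  # [0 0 1 2]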
anonymous_demo/utils/logger.py
ADDED
@@ -0,0 +1,38 @@
import logging
import os
import sys
import time

import termcolor

today = time.strftime('%Y%m%d %H%M%S', time.localtime(time.time()))


def get_logger(log_path, log_name='', log_type='training_log'):
    # use the given log_path as the base directory, otherwise fall back to ./logs
    if log_path:
        log_dir = os.path.join(log_path, "logs")
    else:
        log_dir = os.path.join('.', "logs")

    full_path = os.path.join(log_dir, log_name + '_' + today)
    if not os.path.exists(full_path):
        os.makedirs(full_path)
    log_path = os.path.join(full_path, "{}.log".format(log_type))
    logger = logging.getLogger(log_name)
    if not logger.handlers:
        formatter = logging.Formatter('%(asctime)s %(levelname)s: %(message)s')

        file_handler = logging.FileHandler(log_path, encoding="utf8")
        file_handler.setFormatter(formatter)
        file_handler.setLevel(logging.INFO)

        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.formatter = formatter
        console_handler.setLevel(logging.INFO)

        logger.addHandler(file_handler)
        logger.addHandler(console_handler)

        logger.setLevel(logging.INFO)

    return logger
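
A minimal usage sketch for get_logger (not part of this commit); the 'demo' name is illustrative. With an empty log_path, messages go to ./logs/demo_<timestamp>/training_log.log and to stdout.

    from anonymous_demo.utils.logger import get_logger

    logger = get_logger(log_path='', log_name='demo')  # 'demo' is an illustrative name
    logger.info('checkpoint loaded')                   # written to the log file and stdout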
app.py
ADDED
@@ -0,0 +1,271 @@
import os
import random
import zipfile
from difflib import Differ

import gradio as gr
import nltk
import pandas as pd
from findfile import find_files

from anonymous_demo import TADCheckpointManager
from textattack import Attacker
from textattack.attack_recipes import BAEGarg2019, PWWSRen2019, TextFoolerJin2019, PSOZang2020, IGAWang2019, GeneticAlgorithmAlzantot2018, DeepWordBugGao2018
from textattack.attack_results import SuccessfulAttackResult
from textattack.datasets import Dataset
from textattack.models.wrappers import HuggingFaceModelWrapper

z = zipfile.ZipFile('checkpoints.zip', 'r')
z.extractall(os.getcwd())


class ModelWrapper(HuggingFaceModelWrapper):
    def __init__(self, model):
        self.model = model

    def __call__(self, text_inputs, **kwargs):
        outputs = []
        for text_input in text_inputs:
            raw_outputs = self.model.infer(text_input, print_result=False, **kwargs)
            outputs.append(raw_outputs['probs'])
        return outputs


class SentAttacker:

    def __init__(self, model, recipe_class=BAEGarg2019):
        model_wrapper = ModelWrapper(model)

        recipe = recipe_class.build(model_wrapper)
        # WordNet defaults to English. Set the default language to French ('fra')
        # recipe.transformation.language = "en"

        _dataset = [('', 0)]
        _dataset = Dataset(_dataset)

        self.attacker = Attacker(recipe, _dataset)


def diff_texts(text1, text2):
    d = Differ()
    return [
        (token[2:], token[0] if token[0] != " " else None)
        for token in d.compare(text1, text2)
    ]


def get_ensembled_tad_results(results):
    target_dict = {}
    for r in results:
        target_dict[r['label']] = target_dict.get(r['label'], 0) + 1
    # majority vote over the ensembled predictions
    return max(target_dict, key=target_dict.get)


nltk.download('omw-1.4')

sent_attackers = {}
tad_classifiers = {}

attack_recipes = {
    'bae': BAEGarg2019,
    'pwws': PWWSRen2019,
    'textfooler': TextFoolerJin2019,
    'pso': PSOZang2020,
    'iga': IGAWang2019,
    'GA': GeneticAlgorithmAlzantot2018,
    'wordbugger': DeepWordBugGao2018,
}

for attacker in [
    'pwws',
    'bae',
    'textfooler'
]:
    for dataset in [
        'agnews10k',
        'amazon',
        'sst2',
    ]:
        if 'tad-{}'.format(dataset) not in tad_classifiers:
            tad_classifiers['tad-{}'.format(dataset)] = TADCheckpointManager.get_tad_text_classifier('tad-{}'.format(dataset).upper())

        sent_attackers['tad-{}{}'.format(dataset, attacker)] = SentAttacker(tad_classifiers['tad-{}'.format(dataset)], attack_recipes[attacker])
        tad_classifiers['tad-{}'.format(dataset)].sent_attacker = sent_attackers['tad-{}pwws'.format(dataset)]


def get_a_sst2_example():
    filter_key_words = ['.py', '.md', 'readme', 'log', 'result', 'zip', '.state_dict', '.model', '.png', 'acc_', 'f1_', '.origin', '.adv', '.csv']

    dataset_file = {'train': [], 'test': [], 'valid': []}
    dataset = 'sst2'
    search_path = './'
    task = 'text_defense'
    dataset_file['test'] += find_files(search_path, [dataset, 'test', task], exclude_key=['.adv', '.org', '.defense', '.inference', 'train.'] + filter_key_words)

    for dat_type in [
        'test'
    ]:
        data = []
        label_set = set()
        for data_file in dataset_file[dat_type]:
            with open(data_file, mode='r', encoding='utf8') as fin:
                lines = fin.readlines()
                for line in lines:
                    text, label = line.split('$LABEL$')
                    text = text.strip()
                    label = int(label.strip())
                    data.append((text, label))
                    label_set.add(label)
        return data[random.randint(0, len(data) - 1)]


def get_a_agnews_example():
    filter_key_words = ['.py', '.md', 'readme', 'log', 'result', 'zip', '.state_dict', '.model', '.png', 'acc_', 'f1_', '.origin', '.adv', '.csv']

    dataset_file = {'train': [], 'test': [], 'valid': []}
    dataset = 'agnews'
    search_path = './'
    task = 'text_defense'
    dataset_file['test'] += find_files(search_path, [dataset, 'test', task], exclude_key=['.adv', '.org', '.defense', '.inference', 'train.'] + filter_key_words)
    for dat_type in [
        'test'
    ]:
        data = []
        label_set = set()
        for data_file in dataset_file[dat_type]:
            with open(data_file, mode='r', encoding='utf8') as fin:
                lines = fin.readlines()
                for line in lines:
                    text, label = line.split('$LABEL$')
                    text = text.strip()
                    label = int(label.strip())
                    data.append((text, label))
                    label_set.add(label)
        return data[random.randint(0, len(data) - 1)]


def get_a_amazon_example():
    filter_key_words = ['.py', '.md', 'readme', 'log', 'result', 'zip', '.state_dict', '.model', '.png', 'acc_', 'f1_', '.origin', '.adv', '.csv']

    dataset_file = {'train': [], 'test': [], 'valid': []}
    dataset = 'amazon'
    search_path = './'
    task = 'text_defense'
    dataset_file['test'] += find_files(search_path, [dataset, 'test', task], exclude_key=['.adv', '.org', '.defense', '.inference', 'train.'] + filter_key_words)

    for dat_type in [
        'test'
    ]:
        data = []
        label_set = set()
        for data_file in dataset_file[dat_type]:
            with open(data_file, mode='r', encoding='utf8') as fin:
                lines = fin.readlines()
                for line in lines:
                    text, label = line.split('$LABEL$')
                    text = text.strip()
                    label = int(label.strip())
                    data.append((text, label))
                    label_set.add(label)
        return data[random.randint(0, len(data) - 1)]


def generate_adversarial_example(dataset, attacker, text=None, label=None):
    if not text:
        if 'agnews' in dataset.lower():
            text, label = get_a_agnews_example()
        elif 'sst2' in dataset.lower():
            text, label = get_a_sst2_example()
        elif 'amazon' in dataset.lower():
            text, label = get_a_amazon_example()

    result = None
    attack_result = sent_attackers['tad-{}{}'.format(dataset.lower(), attacker.lower())].attacker.simple_attack(text, int(label))
    if isinstance(attack_result, SuccessfulAttackResult):
        if (attack_result.perturbed_result.output != attack_result.original_result.ground_truth_output) and (
                attack_result.original_result.output == attack_result.original_result.ground_truth_output):
            # with defense
            result = tad_classifiers['tad-{}'.format(dataset.lower())].infer(
                attack_result.perturbed_result.attacked_text.text + '!ref!{},{},{}'.format(
                    attack_result.original_result.ground_truth_output, 1, attack_result.perturbed_result.output),
                print_result=True,
                defense='pwws',
            )

    if result:
        classification_df = {}
        classification_df['pred_label'] = result['label']
        classification_df['confidence'] = round(result['confidence'], 3)
        classification_df['is_correct'] = result['ref_label_check']
        classification_df['is_repaired'] = result['is_fixed']

        advdetection_df = {}
        if result['is_adv_label'] != '0':
            advdetection_df['is_adversary'] = result['is_adv_label']
            advdetection_df['perturbed_label'] = result['perturbed_label']
            advdetection_df['confidence'] = round(result['is_adv_confidence'], 3)
            # advdetection_df['ref_is_attack'] = result['ref_is_adv_label']
            # advdetection_df['is_correct'] = result['ref_is_adv_check']

    else:
        return generate_adversarial_example(dataset, attacker)

    return (text,
            label,
            attack_result.perturbed_result.attacked_text.text,
            diff_texts(text, attack_result.perturbed_result.attacked_text.text),
            diff_texts(text, result['restored_text']),
            attack_result.perturbed_result.output,
            pd.DataFrame(classification_df, index=[0]),
            pd.DataFrame(advdetection_df, index=[0])
            )


demo = gr.Blocks()

with demo:
    with gr.Row():
        with gr.Column():
            input_dataset = gr.Radio(choices=['SST2', 'AGNews10K', 'Amazon'], value='Amazon', label="Dataset")
            input_attacker = gr.Radio(choices=['BAE', 'PWWS', 'TextFooler'], value='TextFooler', label="Attacker")
            input_sentence = gr.Textbox(placeholder='Randomly choose an example from the test set if this box is blank', label="Sentence")
            input_label = gr.Textbox(placeholder='original label ... ', label="Original Label")

            gr.Markdown("Original Example")

            output_origin_example = gr.Textbox(label="Original Example")
            output_original_label = gr.Textbox(label="Original Label")

            gr.Markdown("Adversarial Example")
            output_adv_example = gr.Textbox(label="Adversarial Example")
            output_adv_label = gr.Textbox(label="Perturbed Label")

            gr.Markdown('This demo is deployed on a CPU device so it may take a long time to execute. Please be patient.')
            button_gen = gr.Button("Click Here to Generate an Adversary and Run Adversary Detection & Repair")

        # Right column (outputs)
        with gr.Column():
            gr.Markdown("Example Difference")
            adv_text_diff = gr.HighlightedText(label="Adversarial Example Difference", combine_adjacent=True)
            restored_text_diff = gr.HighlightedText(label="Restored Example Difference", combine_adjacent=True)

            output_is_adv_df = gr.DataFrame(label="Adversary Prediction")
            output_df = gr.DataFrame(label="Standard Classification Prediction")

    # Bind functions to buttons
    button_gen.click(fn=generate_adversarial_example,
                     inputs=[input_dataset, input_attacker, input_sentence, input_label],
                     outputs=[output_origin_example,
                              output_original_label,
                              output_adv_example,
                              adv_text_diff,
                              restored_text_diff,
                              output_adv_label,
                              output_df,
                              output_is_adv_df])

demo.launch()
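
For reference, a minimal sketch (not part of this commit) of what diff_texts feeds the HighlightedText widgets: per-character (text, tag) pairs, where the tag is '+', '-' or None and combine_adjacent=True merges runs with the same tag.

    from difflib import Differ

    d = Differ()
    pairs = [(t[2:], t[0] if t[0] != ' ' else None) for t in d.compare('bad', 'bud')]
    print(pairs)  # [('b', None), ('a', '-'), ('u', '+'), ('d', None)]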
checkpoints.zip
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4a5452cd89dcd3132d616cc81e2a1b063efa7d11e5798719b0779715b1c6edeb
size 1846862527
requirements.txt
ADDED
@@ -0,0 +1,19 @@
findfile>=1.7.9.8
autocuda>=0.11
metric-visualizer>=0.5.5
boostaug>=2.2.3
spacy
networkx
seqeval
update-checker
typing_extensions
tqdm
pytorch_warmup
termcolor
gitpython
gdown>=4.4.0
transformers>4.20.0
torch>1.0.0
sentencepiece
tensorflow_text
textattack
text_defense/201.SST2/stsa.binary.dev.dat
ADDED
The diff for this file is too large to render.

text_defense/201.SST2/stsa.binary.test.dat
ADDED
The diff for this file is too large to render.

text_defense/201.SST2/stsa.binary.train.dat
ADDED
The diff for this file is too large to render.

text_defense/204.AGNews10K/AGNews10K.test.dat
ADDED
The diff for this file is too large to render.

text_defense/204.AGNews10K/AGNews10K.train.dat
ADDED
The diff for this file is too large to render.

text_defense/204.AGNews10K/AGNews10K.valid.dat
ADDED
The diff for this file is too large to render.

text_defense/206.Amazon_Review_Polarity10K/amazon.test.dat
ADDED
The diff for this file is too large to render.

text_defense/206.Amazon_Review_Polarity10K/amazon.train.dat
ADDED
The diff for this file is too large to render.
textattack/__init__.py
ADDED
@@ -0,0 +1,39 @@
"""Welcome to the API references for TextAttack!

What is TextAttack?

`TextAttack <https://github.com/QData/TextAttack>`__ is a Python framework for adversarial attacks, adversarial training, and data augmentation in NLP.

TextAttack makes experimenting with the robustness of NLP models seamless, fast, and easy. It's also useful for NLP model training, adversarial training, and data augmentation.

TextAttack provides components for common NLP tasks like sentence encoding, grammar-checking, and word replacement that can be used on their own.
"""
from .attack_args import AttackArgs, CommandLineAttackArgs
from .augment_args import AugmenterArgs
from .dataset_args import DatasetArgs
from .model_args import ModelArgs
from .training_args import TrainingArgs, CommandLineTrainingArgs
from .attack import Attack
from .attacker import Attacker
from .trainer import Trainer
from .metrics import Metric

from . import (
    attack_recipes,
    attack_results,
    augmentation,
    commands,
    constraints,
    datasets,
    goal_function_results,
    goal_functions,
    loggers,
    metrics,
    models,
    search_methods,
    shared,
    transformations,
)


name = "textattack"
textattack/__main__.py
ADDED
@@ -0,0 +1,6 @@
#!/usr/bin/env python

if __name__ == "__main__":
    import textattack

    textattack.commands.textattack_cli.main()
textattack/attack.py
ADDED
@@ -0,0 +1,492 @@
"""
Attack Class
============
"""

from collections import OrderedDict
from typing import List, Union

import lru
import torch

import textattack
from textattack.attack_results import (
    FailedAttackResult,
    MaximizedAttackResult,
    SkippedAttackResult,
    SuccessfulAttackResult,
)
from textattack.constraints import Constraint, PreTransformationConstraint
from textattack.goal_function_results import GoalFunctionResultStatus
from textattack.goal_functions import GoalFunction
from textattack.models.wrappers import ModelWrapper
from textattack.search_methods import SearchMethod
from textattack.shared import AttackedText, utils
from textattack.transformations import CompositeTransformation, Transformation


class Attack:
    """An attack generates adversarial examples on text.

    An attack is comprised of a goal function, constraints, transformation, and a search method. Use the :meth:`attack` method to attack one sample at a time.

    Args:
        goal_function (:class:`~textattack.goal_functions.GoalFunction`):
            A function for determining how well a perturbation is doing at achieving the attack's goal.
        constraints (list of :class:`~textattack.constraints.Constraint` or :class:`~textattack.constraints.PreTransformationConstraint`):
            A list of constraints to add to the attack, defining which perturbations are valid.
        transformation (:class:`~textattack.transformations.Transformation`):
            The transformation applied at each step of the attack.
        search_method (:class:`~textattack.search_methods.SearchMethod`):
            The method for exploring the search space of possible perturbations
        transformation_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**15`):
            The number of items to keep in the transformations cache
        constraint_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**15`):
            The number of items to keep in the constraints cache

    Example::

        >>> import textattack
        >>> import transformers

        >>> # Load model, tokenizer, and model_wrapper
        >>> model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")
        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
        >>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

        >>> # Construct our four components for `Attack`
        >>> from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
        >>> from textattack.constraints.semantics import WordEmbeddingDistance

        >>> goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
        >>> constraints = [
        ...     RepeatModification(),
        ...     StopwordModification(),
        ...     WordEmbeddingDistance(min_cos_sim=0.9)
        ... ]
        >>> transformation = WordSwapEmbedding(max_candidates=50)
        >>> search_method = GreedyWordSwapWIR(wir_method="delete")

        >>> # Construct the actual attack
        >>> attack = Attack(goal_function, constraints, transformation, search_method)

        >>> input_text = "I really enjoyed the new movie that came out last month."
        >>> label = 1  # Positive
        >>> attack_result = attack.attack(input_text, label)
    """

    def __init__(
        self,
        goal_function: GoalFunction,
        constraints: List[Union[Constraint, PreTransformationConstraint]],
        transformation: Transformation,
        search_method: SearchMethod,
        transformation_cache_size=2**15,
        constraint_cache_size=2**15,
    ):
        """Initialize an attack object.

        Attacks can be run multiple times.
        """
        assert isinstance(
            goal_function, GoalFunction
        ), f"`goal_function` must be of type `textattack.goal_functions.GoalFunction`, but got type `{type(goal_function)}`."
        assert isinstance(
            constraints, list
        ), "`constraints` must be a list of `textattack.constraints.Constraint` or `textattack.constraints.PreTransformationConstraint`."
        for c in constraints:
            assert isinstance(
                c, (Constraint, PreTransformationConstraint)
            ), "`constraints` must be a list of `textattack.constraints.Constraint` or `textattack.constraints.PreTransformationConstraint`."
        assert isinstance(
            transformation, Transformation
        ), f"`transformation` must be of type `textattack.transformations.Transformation`, but got type `{type(transformation)}`."
        assert isinstance(
            search_method, SearchMethod
        ), f"`search_method` must be of type `textattack.search_methods.SearchMethod`, but got type `{type(search_method)}`."

        self.goal_function = goal_function
        self.search_method = search_method
        self.transformation = transformation
        self.is_black_box = (
            getattr(transformation, "is_black_box", True) and search_method.is_black_box
        )

        if not self.search_method.check_transformation_compatibility(
            self.transformation
        ):
            raise ValueError(
                f"SearchMethod {self.search_method} incompatible with transformation {self.transformation}"
            )

        self.constraints = []
        self.pre_transformation_constraints = []
        for constraint in constraints:
            if isinstance(
                constraint,
                textattack.constraints.PreTransformationConstraint,
            ):
                self.pre_transformation_constraints.append(constraint)
            else:
                self.constraints.append(constraint)

        # Check if we can use transformation cache for our transformation.
        if not self.transformation.deterministic:
            self.use_transformation_cache = False
        elif isinstance(self.transformation, CompositeTransformation):
            self.use_transformation_cache = True
            for t in self.transformation.transformations:
                if not t.deterministic:
                    self.use_transformation_cache = False
                    break
        else:
            self.use_transformation_cache = True
        self.transformation_cache_size = transformation_cache_size
        self.transformation_cache = lru.LRU(transformation_cache_size)

        self.constraint_cache_size = constraint_cache_size
        self.constraints_cache = lru.LRU(constraint_cache_size)

        # Give search method access to functions for getting transformations and evaluating them
        self.search_method.get_transformations = self.get_transformations
        # Give search method access to self.goal_function for model query count, etc.
        self.search_method.goal_function = self.goal_function
        # The search method only needs access to the first argument. The second is only used
        # by the attack class when checking whether to skip the sample
        self.search_method.get_goal_results = self.goal_function.get_results

        # Give search method access to get indices which need to be ordered / searched
        self.search_method.get_indices_to_order = self.get_indices_to_order

        self.search_method.filter_transformations = self.filter_transformations

    def clear_cache(self, recursive=True):
        self.constraints_cache.clear()
        if self.use_transformation_cache:
            self.transformation_cache.clear()
        if recursive:
            self.goal_function.clear_cache()
            for constraint in self.constraints:
                if hasattr(constraint, "clear_cache"):
                    constraint.clear_cache()

    def cpu_(self):
        """Move any `torch.nn.Module` models that are part of Attack to CPU."""
        visited = set()

        def to_cpu(obj):
            visited.add(id(obj))
            if isinstance(obj, torch.nn.Module):
                obj.cpu()
            elif isinstance(
                obj,
                (
                    Attack,
                    GoalFunction,
                    Transformation,
                    SearchMethod,
                    Constraint,
                    PreTransformationConstraint,
                    ModelWrapper,
                ),
            ):
                for key in obj.__dict__:
                    s_obj = obj.__dict__[key]
                    if id(s_obj) not in visited:
                        to_cpu(s_obj)
            elif isinstance(obj, (list, tuple)):
                for item in obj:
                    if id(item) not in visited and isinstance(
                        item, (Transformation, Constraint, PreTransformationConstraint)
                    ):
                        to_cpu(item)

        to_cpu(self)

    def cuda_(self):
        """Move any `torch.nn.Module` models that are part of Attack to GPU."""
        visited = set()

        def to_cuda(obj):
            visited.add(id(obj))
            if isinstance(obj, torch.nn.Module):
                obj.to(textattack.shared.utils.device)
            elif isinstance(
                obj,
                (
                    Attack,
                    GoalFunction,
                    Transformation,
                    SearchMethod,
                    Constraint,
                    PreTransformationConstraint,
                    ModelWrapper,
                ),
            ):
                for key in obj.__dict__:
                    s_obj = obj.__dict__[key]
                    if id(s_obj) not in visited:
                        to_cuda(s_obj)
            elif isinstance(obj, (list, tuple)):
                for item in obj:
                    if id(item) not in visited and isinstance(
                        item, (Transformation, Constraint, PreTransformationConstraint)
                    ):
                        to_cuda(item)

        to_cuda(self)

    def get_indices_to_order(self, current_text, **kwargs):
        """Applies ``pre_transformation_constraints`` to ``text`` to get all
        the indices that can be used to search and order.

        Args:
            current_text: The current ``AttackedText`` for which we need to find indices that are eligible to be ordered.
        Returns:
            The length and the filtered list of indices which search methods can use to search/order.
        """

        indices_to_order = self.transformation(
            current_text,
            pre_transformation_constraints=self.pre_transformation_constraints,
            return_indices=True,
            **kwargs,
        )

        len_text = len(indices_to_order)

        # Convert indices_to_order to list for easier shuffling later
        return len_text, list(indices_to_order)

    def _get_transformations_uncached(self, current_text, original_text=None, **kwargs):
        """Applies ``self.transformation`` to ``text``, then filters the list
        of possible transformations through the applicable constraints.

        Args:
            current_text: The current ``AttackedText`` on which to perform the transformations.
            original_text: The original ``AttackedText`` from which the attack started.
        Returns:
            A filtered list of transformations where each transformation matches the constraints
        """
        transformed_texts = self.transformation(
            current_text,
            pre_transformation_constraints=self.pre_transformation_constraints,
            **kwargs,
        )

        return transformed_texts

    def get_transformations(self, current_text, original_text=None, **kwargs):
        """Applies ``self.transformation`` to ``text``, then filters the list
        of possible transformations through the applicable constraints.

        Args:
            current_text: The current ``AttackedText`` on which to perform the transformations.
            original_text: The original ``AttackedText`` from which the attack started.
        Returns:
            A filtered list of transformations where each transformation matches the constraints
        """
        if not self.transformation:
            raise RuntimeError(
                "Cannot call `get_transformations` without a transformation."
            )

        if self.use_transformation_cache:
            cache_key = tuple([current_text] + sorted(kwargs.items()))
            if utils.hashable(cache_key) and cache_key in self.transformation_cache:
                # promote transformed_text to the top of the LRU cache
                self.transformation_cache[cache_key] = self.transformation_cache[
                    cache_key
                ]
                transformed_texts = list(self.transformation_cache[cache_key])
            else:
                transformed_texts = self._get_transformations_uncached(
                    current_text, original_text, **kwargs
                )
                if utils.hashable(cache_key):
                    self.transformation_cache[cache_key] = tuple(transformed_texts)
        else:
            transformed_texts = self._get_transformations_uncached(
                current_text, original_text, **kwargs
            )

        return self.filter_transformations(
            transformed_texts, current_text, original_text
        )

    def _filter_transformations_uncached(
        self, transformed_texts, current_text, original_text=None
    ):
        """Filters a list of potential transformed texts based on
        ``self.constraints``

        Args:
            transformed_texts: A list of candidate transformed ``AttackedText`` to filter.
            current_text: The current ``AttackedText`` on which the transformation was applied.
            original_text: The original ``AttackedText`` from which the attack started.
        """
        filtered_texts = transformed_texts[:]
        for C in self.constraints:
            if len(filtered_texts) == 0:
                break
            if C.compare_against_original:
                if not original_text:
                    raise ValueError(
                        f"Missing `original_text` argument when constraint {type(C)} is set to compare against `original_text`"
                    )

                filtered_texts = C.call_many(filtered_texts, original_text)
            else:
                filtered_texts = C.call_many(filtered_texts, current_text)
        # Default to false for all original transformations.
        for original_transformed_text in transformed_texts:
            self.constraints_cache[(current_text, original_transformed_text)] = False
        # Set unfiltered transformations to True in the cache.
        for filtered_text in filtered_texts:
            self.constraints_cache[(current_text, filtered_text)] = True
        return filtered_texts

    def filter_transformations(
        self, transformed_texts, current_text, original_text=None
    ):
        """Filters a list of potential transformed texts based on
        ``self.constraints`` Utilizes an LRU cache to attempt to avoid
        recomputing common transformations.

        Args:
            transformed_texts: A list of candidate transformed ``AttackedText`` to filter.
            current_text: The current ``AttackedText`` on which the transformation was applied.
            original_text: The original ``AttackedText`` from which the attack started.
        """
        # Remove any occurrences of current_text in transformed_texts
        transformed_texts = [
            t for t in transformed_texts if t.text != current_text.text
        ]
        # Populate cache with transformed_texts
        uncached_texts = []
        filtered_texts = []
        for transformed_text in transformed_texts:
            if (current_text, transformed_text) not in self.constraints_cache:
                uncached_texts.append(transformed_text)
            else:
                # promote transformed_text to the top of the LRU cache
                self.constraints_cache[
                    (current_text, transformed_text)
                ] = self.constraints_cache[(current_text, transformed_text)]
                if self.constraints_cache[(current_text, transformed_text)]:
                    filtered_texts.append(transformed_text)
        filtered_texts += self._filter_transformations_uncached(
            uncached_texts, current_text, original_text=original_text
        )
        # Sort transformations to ensure order is preserved between runs
        filtered_texts.sort(key=lambda t: t.text)
        return filtered_texts

    def _attack(self, initial_result):
        """Calls the ``SearchMethod`` to perturb the ``AttackedText`` stored in
+
``initial_result``.
|
| 388 |
+
|
| 389 |
+
Args:
|
| 390 |
+
initial_result: The initial ``GoalFunctionResult`` from which to perturb.
|
| 391 |
+
|
| 392 |
+
Returns:
|
| 393 |
+
A ``SuccessfulAttackResult``, ``FailedAttackResult``,
|
| 394 |
+
or ``MaximizedAttackResult``.
|
| 395 |
+
"""
|
| 396 |
+
final_result = self.search_method(initial_result)
|
| 397 |
+
self.clear_cache()
|
| 398 |
+
if final_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
|
| 399 |
+
result = SuccessfulAttackResult(
|
| 400 |
+
initial_result,
|
| 401 |
+
final_result,
|
| 402 |
+
)
|
| 403 |
+
elif final_result.goal_status == GoalFunctionResultStatus.SEARCHING:
|
| 404 |
+
result = FailedAttackResult(
|
| 405 |
+
initial_result,
|
| 406 |
+
final_result,
|
| 407 |
+
)
|
| 408 |
+
elif final_result.goal_status == GoalFunctionResultStatus.MAXIMIZING:
|
| 409 |
+
result = MaximizedAttackResult(
|
| 410 |
+
initial_result,
|
| 411 |
+
final_result,
|
| 412 |
+
)
|
| 413 |
+
else:
|
| 414 |
+
raise ValueError(f"Unrecognized goal status {final_result.goal_status}")
|
| 415 |
+
return result
|
| 416 |
+
|
| 417 |
+
def attack(self, example, ground_truth_output):
|
| 418 |
+
"""Attack a single example.
|
| 419 |
+
|
| 420 |
+
Args:
|
| 421 |
+
example (:obj:`str`, :obj:`OrderedDict[str, str]` or :class:`~textattack.shared.AttackedText`):
|
| 422 |
+
Example to attack. It can be a single string or an `OrderedDict` where
|
| 423 |
+
keys represent the input fields (e.g. "premise", "hypothesis") and the values are the actual input textx.
|
| 424 |
+
Also accepts :class:`~textattack.shared.AttackedText` that wraps around the input.
|
| 425 |
+
ground_truth_output(:obj:`int`, :obj:`float` or :obj:`str`):
|
| 426 |
+
Ground truth output of `example`.
|
| 427 |
+
For classification tasks, it should be an integer representing the ground truth label.
|
| 428 |
+
For regression tasks (e.g. STS), it should be the target value.
|
| 429 |
+
For seq2seq tasks (e.g. translation), it should be the target string.
|
| 430 |
+
Returns:
|
| 431 |
+
:class:`~textattack.attack_results.AttackResult` that represents the result of the attack.
|
| 432 |
+
"""
|
| 433 |
+
assert isinstance(
|
| 434 |
+
example, (str, OrderedDict, AttackedText)
|
| 435 |
+
), "`example` must either be `str`, `collections.OrderedDict`, `textattack.shared.AttackedText`."
|
| 436 |
+
if isinstance(example, (str, OrderedDict)):
|
| 437 |
+
example = AttackedText(example)
|
| 438 |
+
|
| 439 |
+
assert isinstance(
|
| 440 |
+
ground_truth_output, (int, str)
|
| 441 |
+
), "`ground_truth_output` must either be `str` or `int`."
|
| 442 |
+
goal_function_result, _ = self.goal_function.init_attack_example(
|
| 443 |
+
example, ground_truth_output
|
| 444 |
+
)
|
| 445 |
+
if goal_function_result.goal_status == GoalFunctionResultStatus.SKIPPED:
|
| 446 |
+
return SkippedAttackResult(goal_function_result)
|
| 447 |
+
else:
|
| 448 |
+
result = self._attack(goal_function_result)
|
| 449 |
+
return result
|
| 450 |
+
|
| 451 |
+
def __repr__(self):
|
| 452 |
+
"""Prints attack parameters in a human-readable string.
|
| 453 |
+
|
| 454 |
+
Inspired by the readability of printing PyTorch nn.Modules:
|
| 455 |
+
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py
|
| 456 |
+
"""
|
| 457 |
+
main_str = "Attack" + "("
|
| 458 |
+
lines = []
|
| 459 |
+
|
| 460 |
+
lines.append(utils.add_indent(f"(search_method): {self.search_method}", 2))
|
| 461 |
+
# self.goal_function
|
| 462 |
+
lines.append(utils.add_indent(f"(goal_function): {self.goal_function}", 2))
|
| 463 |
+
# self.transformation
|
| 464 |
+
lines.append(utils.add_indent(f"(transformation): {self.transformation}", 2))
|
| 465 |
+
# self.constraints
|
| 466 |
+
constraints_lines = []
|
| 467 |
+
constraints = self.constraints + self.pre_transformation_constraints
|
| 468 |
+
if len(constraints):
|
| 469 |
+
for i, constraint in enumerate(constraints):
|
| 470 |
+
constraints_lines.append(utils.add_indent(f"({i}): {constraint}", 2))
|
| 471 |
+
constraints_str = utils.add_indent("\n" + "\n".join(constraints_lines), 2)
|
| 472 |
+
else:
|
| 473 |
+
constraints_str = "None"
|
| 474 |
+
lines.append(utils.add_indent(f"(constraints): {constraints_str}", 2))
|
| 475 |
+
# self.is_black_box
|
| 476 |
+
lines.append(utils.add_indent(f"(is_black_box): {self.is_black_box}", 2))
|
| 477 |
+
main_str += "\n " + "\n ".join(lines) + "\n"
|
| 478 |
+
main_str += ")"
|
| 479 |
+
return main_str
|
| 480 |
+
|
| 481 |
+
def __getstate__(self):
|
| 482 |
+
state = self.__dict__.copy()
|
| 483 |
+
state["transformation_cache"] = None
|
| 484 |
+
state["constraints_cache"] = None
|
| 485 |
+
return state
|
| 486 |
+
|
| 487 |
+
def __setstate__(self, state):
|
| 488 |
+
self.__dict__ = state
|
| 489 |
+
self.transformation_cache = lru.LRU(self.transformation_cache_size)
|
| 490 |
+
self.constraints_cache = lru.LRU(self.constraint_cache_size)
|
| 491 |
+
|
| 492 |
+
__str__ = __repr__
|
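The ``Attack.attack`` entry point above is easiest to understand end to end. Below is a minimal usage sketch; the HuggingFace checkpoint name and recipe choice are illustrative assumptions, not part of this commit:

# Minimal usage sketch (assumed checkpoint name and recipe choice).
import transformers
import textattack

name = "textattack/bert-base-uncased-SST-2"  # assumed HuggingFace checkpoint
model = transformers.AutoModelForSequenceClassification.from_pretrained(name)
tokenizer = transformers.AutoTokenizer.from_pretrained(name)
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

# Any Attack works here; a pre-built recipe is the quickest way to get one.
attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper)

# Returns a Skipped-, Successful-, Failed-, or MaximizedAttackResult.
result = attack.attack("this movie was surprisingly good", 1)
print(result)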
textattack/attack_args.py
ADDED
@@ -0,0 +1,763 @@
"""
AttackArgs Class
================
"""

from dataclasses import dataclass, field
import json
import os
import sys
import time
from typing import Dict, Optional

import textattack
from textattack.shared.utils import ARGS_SPLIT_TOKEN, load_module_from_file

from .attack import Attack
from .dataset_args import DatasetArgs
from .model_args import ModelArgs

ATTACK_RECIPE_NAMES = {
    "alzantot": "textattack.attack_recipes.GeneticAlgorithmAlzantot2018",
    "bae": "textattack.attack_recipes.BAEGarg2019",
    "bert-attack": "textattack.attack_recipes.BERTAttackLi2020",
    "faster-alzantot": "textattack.attack_recipes.FasterGeneticAlgorithmJia2019",
    "deepwordbug": "textattack.attack_recipes.DeepWordBugGao2018",
    "hotflip": "textattack.attack_recipes.HotFlipEbrahimi2017",
    "input-reduction": "textattack.attack_recipes.InputReductionFeng2018",
    "kuleshov": "textattack.attack_recipes.Kuleshov2017",
    "morpheus": "textattack.attack_recipes.MorpheusTan2020",
    "seq2sick": "textattack.attack_recipes.Seq2SickCheng2018BlackBox",
    "textbugger": "textattack.attack_recipes.TextBuggerLi2018",
    "textfooler": "textattack.attack_recipes.TextFoolerJin2019",
    "pwws": "textattack.attack_recipes.PWWSRen2019",
    "iga": "textattack.attack_recipes.IGAWang2019",
    "pruthi": "textattack.attack_recipes.Pruthi2019",
    "pso": "textattack.attack_recipes.PSOZang2020",
    "checklist": "textattack.attack_recipes.CheckList2020",
    "clare": "textattack.attack_recipes.CLARE2020",
    "a2t": "textattack.attack_recipes.A2TYoo2021",
}


BLACK_BOX_TRANSFORMATION_CLASS_NAMES = {
    "random-synonym-insertion": "textattack.transformations.RandomSynonymInsertion",
    "word-deletion": "textattack.transformations.WordDeletion",
    "word-swap-embedding": "textattack.transformations.WordSwapEmbedding",
    "word-swap-homoglyph": "textattack.transformations.WordSwapHomoglyphSwap",
    "word-swap-inflections": "textattack.transformations.WordSwapInflections",
    "word-swap-neighboring-char-swap": "textattack.transformations.WordSwapNeighboringCharacterSwap",
    "word-swap-random-char-deletion": "textattack.transformations.WordSwapRandomCharacterDeletion",
    "word-swap-random-char-insertion": "textattack.transformations.WordSwapRandomCharacterInsertion",
    "word-swap-random-char-substitution": "textattack.transformations.WordSwapRandomCharacterSubstitution",
    "word-swap-wordnet": "textattack.transformations.WordSwapWordNet",
    "word-swap-masked-lm": "textattack.transformations.WordSwapMaskedLM",
    "word-swap-hownet": "textattack.transformations.WordSwapHowNet",
    "word-swap-qwerty": "textattack.transformations.WordSwapQWERTY",
}


WHITE_BOX_TRANSFORMATION_CLASS_NAMES = {
    "word-swap-gradient": "textattack.transformations.WordSwapGradientBased"
}


CONSTRAINT_CLASS_NAMES = {
    #
    # Semantics constraints
    #
    "embedding": "textattack.constraints.semantics.WordEmbeddingDistance",
    "bert": "textattack.constraints.semantics.sentence_encoders.BERT",
    "infer-sent": "textattack.constraints.semantics.sentence_encoders.InferSent",
    "thought-vector": "textattack.constraints.semantics.sentence_encoders.ThoughtVector",
    "use": "textattack.constraints.semantics.sentence_encoders.UniversalSentenceEncoder",
    "muse": "textattack.constraints.semantics.sentence_encoders.MultilingualUniversalSentenceEncoder",
    "bert-score": "textattack.constraints.semantics.BERTScore",
    #
    # Grammaticality constraints
    #
    "lang-tool": "textattack.constraints.grammaticality.LanguageTool",
    "part-of-speech": "textattack.constraints.grammaticality.PartOfSpeech",
    "goog-lm": "textattack.constraints.grammaticality.language_models.GoogleLanguageModel",
    "gpt2": "textattack.constraints.grammaticality.language_models.GPT2",
    "learning-to-write": "textattack.constraints.grammaticality.language_models.LearningToWriteLanguageModel",
    "cola": "textattack.constraints.grammaticality.COLA",
    #
    # Overlap constraints
    #
    "bleu": "textattack.constraints.overlap.BLEU",
    "chrf": "textattack.constraints.overlap.chrF",
    "edit-distance": "textattack.constraints.overlap.LevenshteinEditDistance",
    "meteor": "textattack.constraints.overlap.METEOR",
    "max-words-perturbed": "textattack.constraints.overlap.MaxWordsPerturbed",
    #
    # Pre-transformation constraints
    #
    "repeat": "textattack.constraints.pre_transformation.RepeatModification",
    "stopword": "textattack.constraints.pre_transformation.StopwordModification",
    "max-word-index": "textattack.constraints.pre_transformation.MaxWordIndexModification",
}


SEARCH_METHOD_CLASS_NAMES = {
    "beam-search": "textattack.search_methods.BeamSearch",
    "greedy": "textattack.search_methods.GreedySearch",
    "ga-word": "textattack.search_methods.GeneticAlgorithm",
    "greedy-word-wir": "textattack.search_methods.GreedyWordSwapWIR",
    "pso": "textattack.search_methods.ParticleSwarmOptimization",
}


GOAL_FUNCTION_CLASS_NAMES = {
    #
    # Classification goal functions
    #
    "targeted-classification": "textattack.goal_functions.classification.TargetedClassification",
    "untargeted-classification": "textattack.goal_functions.classification.UntargetedClassification",
    "input-reduction": "textattack.goal_functions.classification.InputReduction",
    #
    # Text goal functions
    #
    "minimize-bleu": "textattack.goal_functions.text.MinimizeBleu",
    "non-overlapping-output": "textattack.goal_functions.text.NonOverlappingOutput",
    "text-to-text": "textattack.goal_functions.text.TextToTextGoalFunction",
}


@dataclass
class AttackArgs:
    """Attack arguments to be passed to :class:`~textattack.Attacker`.

    Args:
        num_examples (:obj:`int`, `optional`, defaults to :obj:`10`):
            The number of examples to attack. :obj:`-1` for the entire dataset.
        num_successful_examples (:obj:`int`, `optional`, defaults to :obj:`None`):
            The number of successful adversarial examples we want. This is different from :obj:`num_examples`,
            as :obj:`num_examples` only cares about attacking `N` samples while :obj:`num_successful_examples` aims to keep attacking
            until we have `N` successful cases.

            .. note::
                If set, this argument overrides the `num_examples` argument.
        num_examples_offset (:obj:`int`, `optional`, defaults to :obj:`0`):
            The offset index to start at in the dataset.
        attack_n (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to run the attack until a total of `N` examples have been attacked (and not skipped).
        shuffle (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If :obj:`True`, we randomly shuffle the dataset before attacking. However, this avoids actually shuffling
            the dataset internally and opts for shuffling the list of indices of examples we want to attack. This means
            :obj:`shuffle` can now be used with checkpoint saving.
        query_budget (:obj:`int`, `optional`, defaults to :obj:`None`):
            The maximum number of model queries allowed per example attacked.
            If not set, we use the query budget set in the :class:`~textattack.goal_functions.GoalFunction` object (which by default is :obj:`float("inf")`).

            .. note::
                Setting this overwrites the query budget set in the :class:`~textattack.goal_functions.GoalFunction` object.
        checkpoint_interval (:obj:`int`, `optional`, defaults to :obj:`None`):
            If set, a checkpoint will be saved after attacking every `N` examples. If :obj:`None` is passed, no checkpoints will be saved.
        checkpoint_dir (:obj:`str`, `optional`, defaults to :obj:`"checkpoints"`):
            The directory to save checkpoint files.
        random_seed (:obj:`int`, `optional`, defaults to :obj:`765`):
            Random seed for reproducibility.
        parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If :obj:`True`, run the attack using multiple CPUs/GPUs.
        num_workers_per_device (:obj:`int`, `optional`, defaults to :obj:`1`):
            Number of worker processes to run per device in parallel mode (i.e. :obj:`parallel=True`). For example, if you are using GPUs and :obj:`num_workers_per_device=2`,
            then 2 processes will be running on each GPU.
        log_to_txt (:obj:`str`, `optional`, defaults to :obj:`None`):
            If set, save attack logs as a `.txt` file to the directory specified by this argument.
            If the last part of the provided path ends with the `.txt` extension, it is assumed to be the desired path of the log file.
        log_to_csv (:obj:`str`, `optional`, defaults to :obj:`None`):
            If set, save attack logs as a CSV file to the directory specified by this argument.
            If the last part of the provided path ends with the `.csv` extension, it is assumed to be the desired path of the log file.
        csv_coloring_style (:obj:`str`, `optional`, defaults to :obj:`"file"`):
            Method for choosing how to mark perturbed parts of the text. Options are :obj:`"file"`, :obj:`"plain"`, and :obj:`"html"`.
            :obj:`"file"` wraps perturbed parts with double brackets :obj:`[[ <text> ]]` while :obj:`"plain"` does not mark the text in any way.
        log_to_visdom (:obj:`dict`, `optional`, defaults to :obj:`None`):
            If set, the Visdom logger is used with the provided dictionary passed as keyword arguments to :class:`~textattack.loggers.VisdomLogger`.
            Pass in an empty dictionary to use the default arguments. For a custom logger, the dictionary should have the following
            three keys and their corresponding values: :obj:`"env", "port", "hostname"`.
        log_to_wandb (:obj:`dict`, `optional`, defaults to :obj:`None`):
            If set, the WandB logger is used with the provided dictionary passed as keyword arguments to :class:`~textattack.loggers.WeightsAndBiasesLogger`.
            Pass in an empty dictionary to use the default arguments. For a custom logger, the dictionary should have the following
            key and its corresponding value: :obj:`"project"`.
        disable_stdout (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Disable displaying individual attack results to stdout.
        silent (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Disable all logging (except for errors). This is stronger than :obj:`disable_stdout`.
        enable_advance_metrics (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Enable calculation and display of optional advanced post-hoc metrics like perplexity, grammar errors, etc.
    """

    num_examples: int = 10
    num_successful_examples: int = None
    num_examples_offset: int = 0
    attack_n: bool = False
    shuffle: bool = False
    query_budget: int = None
    checkpoint_interval: int = None
    checkpoint_dir: str = "checkpoints"
    random_seed: int = 765  # equivalent to sum((ord(c) for c in "TEXTATTACK"))
    parallel: bool = False
    num_workers_per_device: int = 1
    log_to_txt: str = None
    log_to_csv: str = None
    log_summary_to_json: str = None
    csv_coloring_style: str = "file"
    log_to_visdom: dict = None
    log_to_wandb: dict = None
    disable_stdout: bool = False
    silent: bool = False
    enable_advance_metrics: bool = False
    metrics: Optional[Dict] = None

    def __post_init__(self):
        if self.num_successful_examples:
            self.num_examples = None
        if self.num_examples:
            assert (
                self.num_examples >= 0 or self.num_examples == -1
            ), "`num_examples` must be greater than or equal to 0 or equal to -1."
        if self.num_successful_examples:
            assert (
                self.num_successful_examples >= 0
            ), "`num_successful_examples` must be greater than or equal to 0."

        if self.query_budget:
            assert self.query_budget > 0, "`query_budget` must be greater than 0."

        if self.checkpoint_interval:
            assert (
                self.checkpoint_interval > 0
            ), "`checkpoint_interval` must be greater than 0."

        assert (
            self.num_workers_per_device > 0
        ), "`num_workers_per_device` must be greater than 0."

    @classmethod
    def _add_parser_args(cls, parser):
        """Add listed args to command line parser."""
        default_obj = cls()
        num_ex_group = parser.add_mutually_exclusive_group(required=False)
        num_ex_group.add_argument(
            "--num-examples",
            "-n",
            type=int,
            default=default_obj.num_examples,
            help="The number of examples to process, -1 for entire dataset.",
        )
        num_ex_group.add_argument(
            "--num-successful-examples",
            type=int,
            default=default_obj.num_successful_examples,
            help="The number of successful adversarial examples we want.",
        )
        parser.add_argument(
            "--num-examples-offset",
            "-o",
            type=int,
            required=False,
            default=default_obj.num_examples_offset,
            help="The offset to start at in the dataset.",
        )
        parser.add_argument(
            "--query-budget",
            "-q",
            type=int,
            default=default_obj.query_budget,
            help="The maximum number of model queries allowed per example attacked. Setting this overwrites the query budget set in the `GoalFunction` object.",
        )
        parser.add_argument(
            "--shuffle",
            action="store_true",
            default=default_obj.shuffle,
            help="If `True`, shuffle the samples before we attack the dataset. Default is False.",
        )
        parser.add_argument(
            "--attack-n",
            action="store_true",
            default=default_obj.attack_n,
            help="Whether to run the attack until `n` examples have been attacked (not skipped).",
        )
        parser.add_argument(
            "--checkpoint-dir",
            required=False,
            type=str,
            default=default_obj.checkpoint_dir,
            help="The directory to save checkpoint files.",
        )
        parser.add_argument(
            "--checkpoint-interval",
            required=False,
            type=int,
            default=default_obj.checkpoint_interval,
            help="If set, a checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
        )
        parser.add_argument(
            "--random-seed",
            default=default_obj.random_seed,
            type=int,
            help="Random seed for reproducibility.",
        )
        parser.add_argument(
            "--parallel",
            action="store_true",
            default=default_obj.parallel,
            help="Run attack using multiple GPUs.",
        )
        parser.add_argument(
            "--num-workers-per-device",
            default=default_obj.num_workers_per_device,
            type=int,
            help="Number of worker processes to run per device.",
        )
        parser.add_argument(
            "--log-to-txt",
            nargs="?",
            default=default_obj.log_to_txt,
            const="",
            type=str,
            help="Path to which to save attack logs as a text file. Set this argument if you want to save text logs. "
            "If the last part of the path ends with the `.txt` extension, the path is assumed to be the path of the output file.",
        )
        parser.add_argument(
            "--log-to-csv",
            nargs="?",
            default=default_obj.log_to_csv,
            const="",
            type=str,
            help="Path to which to save attack logs as a CSV file. Set this argument if you want to save CSV logs. "
            "If the last part of the path ends with the `.csv` extension, the path is assumed to be the path of the output file.",
        )
        parser.add_argument(
            "--log-summary-to-json",
            nargs="?",
            default=default_obj.log_summary_to_json,
            const="",
            type=str,
            help="Path to which to save the attack summary as a JSON file. Set this argument if you want to save the attack results summary in a JSON. "
            "If the last part of the path ends with the `.json` extension, the path is assumed to be the path of the output file.",
        )
        parser.add_argument(
            "--csv-coloring-style",
            default=default_obj.csv_coloring_style,
            type=str,
            help='Method for choosing how to mark perturbed parts of the text in CSV logs. Options are "file" and "plain". '
            '"file" wraps text with double brackets `[[ <text> ]]` while "plain" does not mark any text. Default is "file".',
        )
        parser.add_argument(
            "--log-to-visdom",
            nargs="?",
            default=None,
            const='{"env": "main", "port": 8097, "hostname": "localhost"}',
            type=json.loads,
            help="Set this argument if you want to log attacks to Visdom. The dictionary should have the following "
            'three keys and their corresponding values: `"env", "port", "hostname"`. '
            'Example for command line use: `--log-to-visdom {"env": "main", "port": 8097, "hostname": "localhost"}`.',
        )
        parser.add_argument(
            "--log-to-wandb",
            nargs="?",
            default=None,
            const='{"project": "textattack"}',
            type=json.loads,
            help="Set this argument if you want to log attacks to WandB. The dictionary should have the following "
            'key and its corresponding value: `"project"`. '
            'Example for command line use: `--log-to-wandb {"project": "textattack"}`.',
        )
        parser.add_argument(
            "--disable-stdout",
            action="store_true",
            default=default_obj.disable_stdout,
            help="Disable logging attack results to stdout",
        )
        parser.add_argument(
            "--silent",
            action="store_true",
            default=default_obj.silent,
            help="Disable all logging",
        )
        parser.add_argument(
            "--enable-advance-metrics",
            action="store_true",
            default=default_obj.enable_advance_metrics,
            help="Enable calculation and display of optional advanced post-hoc metrics like perplexity, USE distance, etc.",
        )

        return parser

    @classmethod
    def create_loggers_from_args(cls, args):
        """Creates an AttackLogManager from an AttackArgs object."""
        assert isinstance(
            args, cls
        ), f"Expect args to be of type `{type(cls)}`, but got type `{type(args)}`."

        # Create logger
        attack_log_manager = textattack.loggers.AttackLogManager(args.metrics)

        # Get current time for file naming
        timestamp = time.strftime("%Y-%m-%d-%H-%M")

        # if '--log-to-txt' specified with arguments
        if args.log_to_txt is not None:
            if args.log_to_txt.lower().endswith(".txt"):
                txt_file_path = args.log_to_txt
            else:
                txt_file_path = os.path.join(args.log_to_txt, f"{timestamp}-log.txt")

            dir_path = os.path.dirname(txt_file_path)
            dir_path = dir_path if dir_path else "."
            if not os.path.exists(dir_path):
                os.makedirs(dir_path)

            color_method = "file"
            attack_log_manager.add_output_file(txt_file_path, color_method)

        # if '--log-to-csv' specified with arguments
        if args.log_to_csv is not None:
            if args.log_to_csv.lower().endswith(".csv"):
                csv_file_path = args.log_to_csv
            else:
                csv_file_path = os.path.join(args.log_to_csv, f"{timestamp}-log.csv")

            dir_path = os.path.dirname(csv_file_path)
            dir_path = dir_path if dir_path else "."
            if not os.path.exists(dir_path):
                os.makedirs(dir_path)

            color_method = (
                None if args.csv_coloring_style == "plain" else args.csv_coloring_style
            )
            attack_log_manager.add_output_csv(csv_file_path, color_method)

        # if '--log-summary-to-json' specified with arguments
        if args.log_summary_to_json is not None:
            if args.log_summary_to_json.lower().endswith(".json"):
                summary_json_file_path = args.log_summary_to_json
            else:
                summary_json_file_path = os.path.join(
                    args.log_summary_to_json, f"{timestamp}-attack_summary_log.json"
                )

            dir_path = os.path.dirname(summary_json_file_path)
            dir_path = dir_path if dir_path else "."
            if not os.path.exists(dir_path):
                os.makedirs(dir_path)

            attack_log_manager.add_output_summary_json(summary_json_file_path)

        # Visdom
        if args.log_to_visdom is not None:
            attack_log_manager.enable_visdom(**args.log_to_visdom)

        # Weights & Biases
        if args.log_to_wandb is not None:
            attack_log_manager.enable_wandb(**args.log_to_wandb)

        # Stdout
        if not args.disable_stdout and not sys.stdout.isatty():
            attack_log_manager.disable_color()
        elif not args.disable_stdout:
            attack_log_manager.enable_stdout()

        return attack_log_manager


@dataclass
class _CommandLineAttackArgs:
    """Attack args for command line execution. This requires more arguments to
    create the ``Attack`` object as specified.

    Args:
        transformation (:obj:`str`, `optional`, defaults to :obj:`"word-swap-embedding"`):
            Name of transformation to use.
        constraints (:obj:`list[str]`, `optional`, defaults to :obj:`["repeat", "stopword"]`):
            List of names of constraints to use.
        goal_function (:obj:`str`, `optional`, defaults to :obj:`"untargeted-classification"`):
            Name of goal function to use.
        search_method (:obj:`str`, `optional`, defaults to :obj:`"greedy-word-wir"`):
            Name of search method to use.
        attack_recipe (:obj:`str`, `optional`, defaults to :obj:`None`):
            Name of attack recipe to use.

            .. note::
                Setting this overrides any previous selection of transformation, constraints, goal function, and search method.
        attack_from_file (:obj:`str`, `optional`, defaults to :obj:`None`):
            Path of the `.py` file from which to load the attack. Use `<path>^<variable_name>` to specify which variable to import from the file.

            .. note::
                If this is set, it overrides any previous selection of transformation, constraints, goal function, and search method.
        interactive (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If `True`, run the attack in interactive mode.
        parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If `True`, attack in parallel.
        model_batch_size (:obj:`int`, `optional`, defaults to :obj:`32`):
            The batch size for making queries to the victim model.
        model_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**18`):
            The maximum number of items to keep in the model results cache at once.
        constraint_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**18`):
            The maximum number of items to keep in the constraints cache at once.
    """

    transformation: str = "word-swap-embedding"
    constraints: list = field(default_factory=lambda: ["repeat", "stopword"])
    goal_function: str = "untargeted-classification"
    search_method: str = "greedy-word-wir"
    attack_recipe: str = None
    attack_from_file: str = None
    interactive: bool = False
    parallel: bool = False
    model_batch_size: int = 32
    model_cache_size: int = 2**18
    constraint_cache_size: int = 2**18

    @classmethod
    def _add_parser_args(cls, parser):
        """Add listed args to command line parser."""
        default_obj = cls()
        transformation_names = set(BLACK_BOX_TRANSFORMATION_CLASS_NAMES.keys()) | set(
            WHITE_BOX_TRANSFORMATION_CLASS_NAMES.keys()
        )
        parser.add_argument(
            "--transformation",
            type=str,
            required=False,
            default=default_obj.transformation,
            help='The transformation to apply. Usage: "--transformation {transformation}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
            + str(transformation_names),
        )
        parser.add_argument(
            "--constraints",
            type=str,
            required=False,
            nargs="*",
            default=default_obj.constraints,
            help='Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
            + str(CONSTRAINT_CLASS_NAMES.keys()),
        )
        goal_function_choices = ", ".join(GOAL_FUNCTION_CLASS_NAMES.keys())
        parser.add_argument(
            "--goal-function",
            "-g",
            default=default_obj.goal_function,
            help=f"The goal function to use. choices: {goal_function_choices}",
        )
        attack_group = parser.add_mutually_exclusive_group(required=False)
        search_choices = ", ".join(SEARCH_METHOD_CLASS_NAMES.keys())
        attack_group.add_argument(
            "--search-method",
            "--search",
            "-s",
            type=str,
            required=False,
            default=default_obj.search_method,
            help=f"The search method to use. choices: {search_choices}",
        )
        attack_group.add_argument(
            "--attack-recipe",
            "--recipe",
            "-r",
            type=str,
            required=False,
            default=default_obj.attack_recipe,
            help="full attack recipe (overrides provided goal function, transformation & constraints)",
            choices=ATTACK_RECIPE_NAMES.keys(),
        )
        attack_group.add_argument(
            "--attack-from-file",
            type=str,
            required=False,
            default=default_obj.attack_from_file,
            help="Path of the `.py` file from which to load the attack. Use `<path>^<variable_name>` to specify which variable to import from the file.",
        )
        parser.add_argument(
            "--interactive",
            action="store_true",
            default=default_obj.interactive,
            help="Whether to run attacks interactively.",
        )
        parser.add_argument(
            "--model-batch-size",
            type=int,
            default=default_obj.model_batch_size,
            help="The batch size for making calls to the model.",
        )
        parser.add_argument(
            "--model-cache-size",
            type=int,
            default=default_obj.model_cache_size,
            help="The maximum number of items to keep in the model results cache at once.",
        )
        parser.add_argument(
            "--constraint-cache-size",
            type=int,
            default=default_obj.constraint_cache_size,
            help="The maximum number of items to keep in the constraints cache at once.",
        )

        return parser

    @classmethod
    def _create_transformation_from_args(cls, args, model_wrapper):
        """Create `Transformation` based on provided `args` and
        `model_wrapper`."""

        transformation_name = args.transformation
        if ARGS_SPLIT_TOKEN in transformation_name:
            transformation_name, params = transformation_name.split(ARGS_SPLIT_TOKEN)

            if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
                transformation = eval(
                    f"{WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}(model_wrapper.model, {params})"
                )
            elif transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
                transformation = eval(
                    f"{BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}({params})"
                )
            else:
                raise ValueError(
                    f"Error: unsupported transformation {transformation_name}"
                )
        else:
            if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
                transformation = eval(
                    f"{WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}(model_wrapper.model)"
                )
            elif transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
                transformation = eval(
                    f"{BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}()"
                )
            else:
                raise ValueError(
                    f"Error: unsupported transformation {transformation_name}"
                )
        return transformation

    @classmethod
    def _create_goal_function_from_args(cls, args, model_wrapper):
        """Create `GoalFunction` based on provided `args` and
        `model_wrapper`."""

        goal_function = args.goal_function
        if ARGS_SPLIT_TOKEN in goal_function:
            goal_function_name, params = goal_function.split(ARGS_SPLIT_TOKEN)
            if goal_function_name not in GOAL_FUNCTION_CLASS_NAMES:
                raise ValueError(
                    f"Error: unsupported goal_function {goal_function_name}"
                )
            goal_function = eval(
                f"{GOAL_FUNCTION_CLASS_NAMES[goal_function_name]}(model_wrapper, {params})"
            )
        elif goal_function in GOAL_FUNCTION_CLASS_NAMES:
            goal_function = eval(
                f"{GOAL_FUNCTION_CLASS_NAMES[goal_function]}(model_wrapper)"
            )
        else:
            raise ValueError(f"Error: unsupported goal_function {goal_function}")
        if args.query_budget:
            goal_function.query_budget = args.query_budget
        goal_function.model_cache_size = args.model_cache_size
        goal_function.batch_size = args.model_batch_size
        return goal_function

    @classmethod
    def _create_constraints_from_args(cls, args):
        """Create list of `Constraints` based on provided `args`."""

        if not args.constraints:
            return []

        _constraints = []
        for constraint in args.constraints:
            if ARGS_SPLIT_TOKEN in constraint:
                constraint_name, params = constraint.split(ARGS_SPLIT_TOKEN)
                if constraint_name not in CONSTRAINT_CLASS_NAMES:
                    raise ValueError(f"Error: unsupported constraint {constraint_name}")
                _constraints.append(
                    eval(f"{CONSTRAINT_CLASS_NAMES[constraint_name]}({params})")
                )
            elif constraint in CONSTRAINT_CLASS_NAMES:
                _constraints.append(eval(f"{CONSTRAINT_CLASS_NAMES[constraint]}()"))
            else:
                raise ValueError(f"Error: unsupported constraint {constraint}")

        return _constraints

    @classmethod
    def _create_attack_from_args(cls, args, model_wrapper):
        """Given ``CommandLineArgs`` and ``ModelWrapper``, return the specified
        ``Attack`` object."""

        assert isinstance(
            args, cls
        ), f"Expect args to be of type `{type(cls)}`, but got type `{type(args)}`."

        if args.attack_recipe:
            if ARGS_SPLIT_TOKEN in args.attack_recipe:
                recipe_name, params = args.attack_recipe.split(ARGS_SPLIT_TOKEN)
                if recipe_name not in ATTACK_RECIPE_NAMES:
                    raise ValueError(f"Error: unsupported recipe {recipe_name}")
                recipe = eval(
                    f"{ATTACK_RECIPE_NAMES[recipe_name]}.build(model_wrapper, {params})"
                )
            elif args.attack_recipe in ATTACK_RECIPE_NAMES:
                recipe = eval(
                    f"{ATTACK_RECIPE_NAMES[args.attack_recipe]}.build(model_wrapper)"
                )
            else:
                raise ValueError(f"Invalid recipe {args.attack_recipe}")
            if args.query_budget:
                recipe.goal_function.query_budget = args.query_budget
            recipe.goal_function.model_cache_size = args.model_cache_size
            recipe.constraint_cache_size = args.constraint_cache_size
            return recipe
        elif args.attack_from_file:
            if ARGS_SPLIT_TOKEN in args.attack_from_file:
                attack_file, attack_name = args.attack_from_file.split(ARGS_SPLIT_TOKEN)
            else:
                attack_file, attack_name = args.attack_from_file, "attack"
            attack_module = load_module_from_file(attack_file)
            if not hasattr(attack_module, attack_name):
                raise ValueError(
                    f"Loaded `{attack_file}` but could not find `{attack_name}`."
                )
            attack_func = getattr(attack_module, attack_name)
            return attack_func(model_wrapper)
        else:
            goal_function = cls._create_goal_function_from_args(args, model_wrapper)
            transformation = cls._create_transformation_from_args(args, model_wrapper)
            constraints = cls._create_constraints_from_args(args)
            if ARGS_SPLIT_TOKEN in args.search_method:
                search_name, params = args.search_method.split(ARGS_SPLIT_TOKEN)
                if search_name not in SEARCH_METHOD_CLASS_NAMES:
                    raise ValueError(f"Error: unsupported search {search_name}")
                search_method = eval(
                    f"{SEARCH_METHOD_CLASS_NAMES[search_name]}({params})"
                )
            elif args.search_method in SEARCH_METHOD_CLASS_NAMES:
                search_method = eval(
                    f"{SEARCH_METHOD_CLASS_NAMES[args.search_method]}()"
                )
            else:
                raise ValueError(f"Error: unsupported attack {args.search_method}")

            return Attack(
                goal_function,
                constraints,
                transformation,
                search_method,
                constraint_cache_size=args.constraint_cache_size,
            )


# This neat trick allows us to reorder the arguments to avoid TypeErrors commonly found when inheriting a dataclass.
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
@dataclass
class CommandLineAttackArgs(AttackArgs, _CommandLineAttackArgs, DatasetArgs, ModelArgs):
    @classmethod
    def _add_parser_args(cls, parser):
        """Add listed args to command line parser."""
        parser = ModelArgs._add_parser_args(parser)
        parser = DatasetArgs._add_parser_args(parser)
        parser = _CommandLineAttackArgs._add_parser_args(parser)
        parser = AttackArgs._add_parser_args(parser)
        return parser
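A minimal sketch of how ``AttackArgs`` is consumed through the public ``textattack.Attacker`` entry point; the checkpoint, dataset, and recipe named here are illustrative assumptions, not part of this commit:

# Minimal usage sketch (assumed checkpoint, dataset, and recipe).
import transformers
import textattack

name = "textattack/bert-base-uncased-SST-2"  # assumed checkpoint name
model = transformers.AutoModelForSequenceClassification.from_pretrained(name)
tokenizer = transformers.AutoTokenizer.from_pretrained(name)
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

attack = textattack.attack_recipes.PWWSRen2019.build(model_wrapper)
dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="validation")

attack_args = textattack.AttackArgs(
    num_examples=20,         # stop after attacking 20 examples
    log_to_csv="results/",   # timestamped CSV log written under results/
    checkpoint_interval=10,  # save a checkpoint every 10 examples
    disable_stdout=True,     # keep per-example output off stdout
)
textattack.Attacker(attack, dataset, attack_args).attack_dataset()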
textattack/attack_recipes/__init__.py
ADDED
@@ -0,0 +1,43 @@
""".. _attack_recipes:

Attack Recipes Package
======================

We provide a number of pre-built attack recipes, which correspond to attacks from the literature. To run an attack recipe from the command line, run::

    textattack attack --recipe [recipe_name]

To initialize an attack in a Python script, use::

    <recipe name>.build(model_wrapper)

For example, ``attack = InputReductionFeng2018.build(model)`` creates `attack`, an object of type ``Attack`` with the goal function, transformation, constraints, and search method specified in that paper. This object can then be used just like any other attack; for example, by calling ``attack.attack_dataset``.

TextAttack supports the following attack recipes (each recipe's documentation contains a link to the corresponding paper):

.. contents:: :local:
"""

from .attack_recipe import AttackRecipe

from .a2t_yoo_2021 import A2TYoo2021
from .bae_garg_2019 import BAEGarg2019
from .bert_attack_li_2020 import BERTAttackLi2020
from .genetic_algorithm_alzantot_2018 import GeneticAlgorithmAlzantot2018
from .faster_genetic_algorithm_jia_2019 import FasterGeneticAlgorithmJia2019
from .deepwordbug_gao_2018 import DeepWordBugGao2018
from .hotflip_ebrahimi_2017 import HotFlipEbrahimi2017
from .input_reduction_feng_2018 import InputReductionFeng2018
from .kuleshov_2017 import Kuleshov2017
from .morpheus_tan_2020 import MorpheusTan2020
from .seq2sick_cheng_2018_blackbox import Seq2SickCheng2018BlackBox
from .textbugger_li_2018 import TextBuggerLi2018
from .textfooler_jin_2019 import TextFoolerJin2019
from .pwws_ren_2019 import PWWSRen2019
from .iga_wang_2019 import IGAWang2019
from .pruthi_2019 import Pruthi2019
from .pso_zang_2020 import PSOZang2020
from .checklist_ribeiro_2020 import CheckList2020
from .clare_li_2020 import CLARE2020
from .french_recipe import FrenchRecipe
from .spanish_recipe import SpanishRecipe
textattack/attack_recipes/a2t_yoo_2021.py
ADDED
@@ -0,0 +1,74 @@
+"""
+A2T (Attack for Adversarial Training) Recipe
+==================================================
+
+"""
+
+from textattack import Attack
+from textattack.constraints.grammaticality import PartOfSpeech
+from textattack.constraints.pre_transformation import (
+    InputColumnModification,
+    MaxModificationRate,
+    RepeatModification,
+    StopwordModification,
+)
+from textattack.constraints.semantics import WordEmbeddingDistance
+from textattack.constraints.semantics.sentence_encoders import BERT
+from textattack.goal_functions import UntargetedClassification
+from textattack.search_methods import GreedyWordSwapWIR
+from textattack.transformations import WordSwapEmbedding, WordSwapMaskedLM
+
+from .attack_recipe import AttackRecipe
+
+
+class A2TYoo2021(AttackRecipe):
+    """Towards Improving Adversarial Training of NLP Models.
+
+    (Yoo et al., 2021)
+
+    https://arxiv.org/abs/2109.00544
+    """
+
+    @staticmethod
+    def build(model_wrapper, mlm=False):
+        """Build attack recipe.
+
+        Args:
+            model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
+                Model wrapper containing both the model and the tokenizer.
+            mlm (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                If :obj:`True`, load the `A2T-MLM` attack. Otherwise, load the regular `A2T` attack.
+
+        Returns:
+            :class:`~textattack.Attack`: A2T attack.
+        """
+        constraints = [RepeatModification(), StopwordModification()]
+        input_column_modification = InputColumnModification(
+            ["premise", "hypothesis"], {"premise"}
+        )
+        constraints.append(input_column_modification)
+        constraints.append(PartOfSpeech(allow_verb_noun_swap=False))
+        constraints.append(MaxModificationRate(max_rate=0.1, min_threshold=4))
+        sent_encoder = BERT(
+            model_name="stsb-distilbert-base", threshold=0.9, metric="cosine"
+        )
+        constraints.append(sent_encoder)
+
+        if mlm:
+            transformation = WordSwapMaskedLM(
+                method="bae", max_candidates=20, min_confidence=0.0, batch_size=16
+            )
+        else:
+            transformation = WordSwapEmbedding(max_candidates=20)
+            constraints.append(WordEmbeddingDistance(min_cos_sim=0.8))
+
+        #
+        # Goal is untargeted classification.
+        #
+        goal_function = UntargetedClassification(model_wrapper, model_batch_size=32)
+        #
+        # Greedily swap words with "Word Importance Ranking".
+        #
+        search_method = GreedyWordSwapWIR(wir_method="gradient")
+
+        return Attack(goal_function, constraints, transformation, search_method)
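For context, building and running this recipe takes a single call. A minimal sketch, assuming a HuggingFace sequence-classification checkpoint (the name `textattack/bert-base-uncased-SST-2` is illustrative, not something this file ships):

# Hypothetical usage sketch; the checkpoint name is an assumption.
import transformers
from textattack.attack_recipes import A2TYoo2021
from textattack.models.wrappers import HuggingFaceModelWrapper

name = "textattack/bert-base-uncased-SST-2"  # illustrative victim model
model = transformers.AutoModelForSequenceClassification.from_pretrained(name)
tokenizer = transformers.AutoTokenizer.from_pretrained(name)
wrapper = HuggingFaceModelWrapper(model, tokenizer)

attack = A2TYoo2021.build(wrapper)                # embedding-swap A2T
attack_mlm = A2TYoo2021.build(wrapper, mlm=True)  # masked-LM A2T-MLM variant

Note that `GreedyWordSwapWIR(wir_method="gradient")` ranks words by gradient, so the wrapper must expose model gradients; purely black-box wrappers will not work with this recipe.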
textattack/attack_recipes/attack_recipe.py
ADDED
@@ -0,0 +1,30 @@
+"""
+Attack Recipe Class
+========================
+
+"""
+
+from abc import ABC, abstractmethod
+
+from textattack import Attack
+
+
+class AttackRecipe(Attack, ABC):
+    """A recipe for building an NLP adversarial attack from the literature."""
+
+    @staticmethod
+    @abstractmethod
+    def build(model_wrapper, **kwargs):
+        """Creates a pre-built :class:`~textattack.Attack` that corresponds to
+        an attack from the literature.
+
+        Args:
+            model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
+                :class:`~textattack.models.wrappers.ModelWrapper` that contains the victim model and tokenizer.
+                This is passed to :class:`~textattack.goal_functions.GoalFunction` when constructing the attack.
+            kwargs:
+                Additional keyword arguments.
+        Returns:
+            :class:`~textattack.Attack`
+        """
+        raise NotImplementedError()
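The contract this base class imposes is small: subclass it, implement a static `build`, and wire a goal function, constraints, a transformation, and a search method into an :class:`~textattack.Attack`. A minimal sketch of a hypothetical subclass (the recipe name and component choices below are invented for illustration, not part of the library):

# Hypothetical toy recipe illustrating the AttackRecipe contract.
from textattack import Attack
from textattack.attack_recipes import AttackRecipe
from textattack.constraints.pre_transformation import RepeatModification
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import GreedySearch
from textattack.transformations import WordSwapEmbedding


class ToyGreedyRecipe(AttackRecipe):  # invented name, for illustration only
    """Greedy embedding-swap attack with a single constraint."""

    @staticmethod
    def build(model_wrapper, **kwargs):
        transformation = WordSwapEmbedding(max_candidates=10)
        constraints = [RepeatModification()]  # never swap the same index twice
        goal_function = UntargetedClassification(model_wrapper)
        search_method = GreedySearch()
        return Attack(goal_function, constraints, transformation, search_method)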
textattack/attack_recipes/bae_garg_2019.py
ADDED
@@ -0,0 +1,123 @@
+"""
+BAE (BAE: BERT-Based Adversarial Examples)
+============================================
+
+"""
+from textattack.constraints.grammaticality import PartOfSpeech
+from textattack.constraints.pre_transformation import (
+    RepeatModification,
+    StopwordModification,
+)
+from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
+from textattack.goal_functions import UntargetedClassification
+from textattack.search_methods import GreedyWordSwapWIR
+from textattack.transformations import WordSwapMaskedLM
+
+from .attack_recipe import AttackRecipe
+
+
+class BAEGarg2019(AttackRecipe):
+    """Siddhant Garg and Goutham Ramakrishnan, 2019.
+
+    BAE: BERT-based Adversarial Examples for Text Classification.
+
+    https://arxiv.org/pdf/2004.01970
+
+    This is "attack mode" 1 from the paper, BAE-R, word replacement.
+
+    We present 4 attack modes for BAE based on the
+    R and I operations, where for each token t in S:
+    • BAE-R: Replace token t (See Algorithm 1)
+    • BAE-I: Insert a token to the left or right of t
+    • BAE-R/I: Either replace token t or insert a
+      token to the left or right of t
+    • BAE-R+I: First replace token t, then insert a
+      token to the left or right of t
+    """
+
+    @staticmethod
+    def build(model_wrapper):
+        # "In this paper, we present a simple yet novel technique: BAE (BERT-based
+        # Adversarial Examples), which uses a language model (LM) for token
+        # replacement to best fit the overall context. We perturb an input sentence
+        # by either replacing a token or inserting a new token in the sentence, by
+        # means of masking a part of the input and using a LM to fill in the mask."
+        #
+        # We only consider the top K=50 synonyms from the MLM predictions.
+        #
+        # [from email correspondence with the author]
+        # "When choosing the top-K candidates from the BERT masked LM, we filter out
+        # the sub-words and only retain the whole words (by checking if they are
+        # present in the GloVE vocabulary)"
+        #
+        transformation = WordSwapMaskedLM(
+            method="bae", max_candidates=50, min_confidence=0.0
+        )
+        #
+        # Don't modify the same word twice or stopwords.
+        #
+        constraints = [RepeatModification(), StopwordModification()]
+
+        # For the R operations we add an additional check for
+        # grammatical correctness of the generated adversarial example by filtering
+        # out predicted tokens that do not form the same part of speech (POS) as the
+        # original token t_i in the sentence.
+        constraints.append(PartOfSpeech(allow_verb_noun_swap=True))
+
+        # "To ensure semantic similarity on introducing perturbations in the input
+        # text, we filter the set of top-K masked tokens (K is a pre-defined
+        # constant) predicted by BERT-MLM using a Universal Sentence Encoder (USE)
+        # (Cer et al., 2018)-based sentence similarity scorer."
+        #
+        # "[We] set a threshold of 0.8 for the cosine similarity between USE-based
+        # embeddings of the adversarial and input text."
+        #
+        # [from email correspondence with the author]
+        # "For a fair comparison of the benefits of using a BERT-MLM in our paper,
+        # we retained the majority of TextFooler's specifications. Thus we:
+        # 1. Use the USE for comparison within a window of size 15 around the word
+        # being replaced/inserted.
+        # 2. Set the similarity score threshold to 0.1 for inputs shorter than the
+        # window size (this translates roughly to almost always accepting the new text).
+        # 3. Perform the USE similarity thresholding of 0.8 with respect to the text
+        # just before the replacement/insertion and not the original text (For
+        # example: at the 3rd R/I operation, we compute the USE score on a window
+        # of size 15 of the text obtained after the first 2 R/I operations and not
+        # the original text).
+        # ...
+        # To address point (3) from above, compare the USE with the original text
+        # at each iteration instead of the current one (While doing this change
+        # for the R-operation is trivial, doing it for the I-operation with the
+        # window based USE comparison might be more involved)."
+        #
+        # Finally, since the BAE code is based on the TextFooler code, we need to
+        # adjust the threshold to account for the missing `/ pi` in the cosine
+        # similarity comparison. So the final threshold is 1 - (1 - 0.8) / pi
+        # = 1 - (0.2 / pi) = 0.936338023.
+        use_constraint = UniversalSentenceEncoder(
+            threshold=0.936338023,
+            metric="cosine",
+            compare_against_original=True,
+            window_size=15,
+            skip_text_shorter_than_window=True,
+        )
+        constraints.append(use_constraint)
+        #
+        # Goal is untargeted classification.
+        #
+        goal_function = UntargetedClassification(model_wrapper)
+        #
+        # "We estimate the token importance I_i of each token
+        # t_i ∈ S = [t_1, ..., t_n], by deleting t_i from S and computing the
+        # decrease in probability of predicting the correct label y, similar
+        # to (Jin et al., 2019)."
+        #
+        # • "If there are multiple tokens [that] can cause C to misclassify S when
+        # they replace the mask, we choose the token which makes S_adv most similar
+        # to the original S based on the USE score."
+        # • "If no token causes misclassification, we choose the perturbation that
+        # decreases the prediction probability P(C(S_adv)=y) the most."
+        #
+        search_method = GreedyWordSwapWIR(wir_method="delete")
+
+        return BAEGarg2019(goal_function, constraints, transformation, search_method)
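A quick sanity check on the threshold arithmetic in the comments above: 1 - (1 - 0.8) / pi = 1 - 0.2 / 3.14159... ≈ 1 - 0.063662 = 0.936338, matching the hard-coded value. For usage, a minimal hypothetical sketch (the checkpoint name is an assumption, and `UniversalSentenceEncoder` additionally requires TensorFlow/TF-Hub at runtime; the `attack(example, ground_truth_output)` entry point is the one defined in this repo's attack.py):

# Hypothetical usage sketch; the checkpoint name is an assumption.
import transformers
from textattack.attack_recipes import BAEGarg2019
from textattack.models.wrappers import HuggingFaceModelWrapper

name = "textattack/bert-base-uncased-imdb"  # illustrative victim model
model = transformers.AutoModelForSequenceClassification.from_pretrained(name)
tokenizer = transformers.AutoTokenizer.from_pretrained(name)

attack = BAEGarg2019.build(HuggingFaceModelWrapper(model, tokenizer))
# attack(...) returns an AttackResult subclass, e.g. SuccessfulAttackResult
result = attack.attack("a gripping, beautifully shot film", 1)
print(type(result).__name__)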