ngohuudang committed
Commit · 1b76ad1
1 Parent(s): 9bb5ff5

update file
Files changed:
- .gitattributes +1 -9
- __pycache__/gec_model.cpython-310.pyc +0 -0
- __pycache__/gec_model.cpython-311.pyc +0 -0
- __pycache__/gec_model.cpython-39.pyc +0 -0
- __pycache__/modeling_seq2labels.cpython-310.pyc +0 -0
- __pycache__/modeling_seq2labels.cpython-311.pyc +0 -0
- __pycache__/modeling_seq2labels.cpython-39.pyc +0 -0
- __pycache__/utils.cpython-310.pyc +0 -0
- __pycache__/utils_gec.cpython-310.pyc +0 -0
- __pycache__/utils_gec.cpython-311.pyc +0 -0
- __pycache__/utils_gec.cpython-39.pyc +0 -0
- __pycache__/vocabulary.cpython-310.pyc +0 -0
- __pycache__/vocabulary.cpython-311.pyc +0 -0
- __pycache__/vocabulary.cpython-39.pyc +0 -0
- config.json +18 -0
- configuration_seq2labels.py +62 -0
- gec_model.py +449 -0
- modeling_seq2labels.py +124 -0
- pytorch_model.bin +3 -0
- utils_gec.py +233 -0
- verb-form-vocab.txt +0 -0
- vocabulary.py +277 -0
- vocabulary/d_tags.txt +4 -0
- vocabulary/labels.txt +15 -0
- vocabulary/non_padded_namespaces.txt +2 -0
- xlm-roberta-base/config.json +25 -0
- xlm-roberta-base/tokenizer.json +0 -0
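
Together these files package a Seq2Labels grammatical-error-correction model: modeling_seq2labels.py defines the tagger, gec_model.py wraps it for iterative inference, and vocabulary/ holds the label set. A minimal usage sketch, assuming the repository root is the working directory (the paths and the input sentence are illustrative, not part of the commit itself):

# Hypothetical quick-start for the wrapper defined in gec_model.py below.
from gec_model import GecBERTModel

model = GecBERTModel(
    vocab_path="vocabulary",   # d_tags.txt / labels.txt shipped in this commit
    model_paths=["./"],        # directory holding config.json + pytorch_model.bin
    split_chunk=True,          # segment long inputs into overlapping chunks
)
print(model("toi di hoc"))     # returns a list with one corrected string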
.gitattributes
CHANGED
@@ -2,34 +2,26 @@
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
 *.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
 *.ftz filter=lfs diff=lfs merge=lfs -text
 *.gz filter=lfs diff=lfs merge=lfs -text
 *.h5 filter=lfs diff=lfs merge=lfs -text
 *.joblib filter=lfs diff=lfs merge=lfs -text
 *.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
 *.model filter=lfs diff=lfs merge=lfs -text
 *.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
 *.onnx filter=lfs diff=lfs merge=lfs -text
 *.ot filter=lfs diff=lfs merge=lfs -text
 *.parquet filter=lfs diff=lfs merge=lfs -text
 *.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
-*.
+*.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
__pycache__/gec_model.cpython-310.pyc
ADDED
Binary file (14.1 kB)

__pycache__/gec_model.cpython-311.pyc
ADDED
Binary file (25.9 kB)

__pycache__/gec_model.cpython-39.pyc
ADDED
Binary file (14.2 kB)

__pycache__/modeling_seq2labels.cpython-310.pyc
ADDED
Binary file (3.97 kB)

__pycache__/modeling_seq2labels.cpython-311.pyc
ADDED
Binary file (7.03 kB)

__pycache__/modeling_seq2labels.cpython-39.pyc
ADDED
Binary file (4.06 kB)

__pycache__/utils.cpython-310.pyc
ADDED
Binary file (6.13 kB)

__pycache__/utils_gec.cpython-310.pyc
ADDED
Binary file (6.14 kB)

__pycache__/utils_gec.cpython-311.pyc
ADDED
Binary file (11.8 kB)

__pycache__/utils_gec.cpython-39.pyc
ADDED
Binary file (6.12 kB)

__pycache__/vocabulary.cpython-310.pyc
ADDED
Binary file (12.9 kB)

__pycache__/vocabulary.cpython-311.pyc
ADDED
Binary file (18.9 kB)

__pycache__/vocabulary.cpython-39.pyc
ADDED
Binary file (13 kB)
config.json
ADDED
@@ -0,0 +1,18 @@
{
  "architectures": [
    "Seq2LabelsModel"
  ],
  "initializer_range": 0.02,
  "label_smoothing": 0.0,
  "load_pretrained": false,
  "model_type": "bert",
  "num_detect_classes": 4,
  "pad_token_id": 0,
  "predictor_dropout": 0.0,
  "pretrained_name_or_path": "xlm-roberta-capu/xlm-roberta-base",
  "special_tokens_fix": true,
  "torch_dtype": "float32",
  "transformers_version": "4.18.0",
  "use_cache": true,
  "vocab_size": 15
}
configuration_seq2labels.py
ADDED
@@ -0,0 +1,62 @@
from transformers import PretrainedConfig


class Seq2LabelsConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Seq2LabelsModel`]. It is used to
    instantiate a Seq2Labels model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a configuration similar to that of the
    Seq2Labels architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 15):
            Size of the label vocabulary, i.e. the number of correction labels the classification head can
            predict.
        pretrained_name_or_path (`str`, *optional*, defaults to `bert-base-cased`):
            Path to the pretrained BERT-like encoder.
        load_pretrained (`bool`, *optional*, defaults to `False`):
            Whether to load pretrained encoder weights from `pretrained_name_or_path`.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        predictor_dropout (`float`, *optional*):
            The dropout ratio for the classification head.
        special_tokens_fix (`bool`, *optional*, defaults to `False`):
            Whether to add additional tokens to the encoder's embedding layer.

    Examples:
    ```python
    >>> from configuration_seq2labels import Seq2LabelsConfig
    >>> from modeling_seq2labels import Seq2LabelsModel

    >>> # Initializing a Seq2Labels style configuration
    >>> configuration = Seq2LabelsConfig()

    >>> # Initializing a model from that configuration
    >>> model = Seq2LabelsModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "bert"

    def __init__(
        self,
        pretrained_name_or_path="bert-base-cased",
        vocab_size=15,
        num_detect_classes=4,
        load_pretrained=False,
        initializer_range=0.02,
        pad_token_id=0,
        use_cache=True,
        predictor_dropout=0.0,
        special_tokens_fix=False,
        label_smoothing=0.0,
        **kwargs
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.num_detect_classes = num_detect_classes
        self.pretrained_name_or_path = pretrained_name_or_path
        self.load_pretrained = load_pretrained
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.predictor_dropout = predictor_dropout
        self.special_tokens_fix = special_tokens_fix
        self.label_smoothing = label_smoothing
gec_model.py
ADDED
@@ -0,0 +1,449 @@
"""Wrapper of the Seq2Labels model. Fixes errors based on model predictions."""
from collections import defaultdict
from difflib import SequenceMatcher
import logging
import re
import sys
from time import time
from typing import List, Union

import torch
from transformers import AutoTokenizer

from modeling_seq2labels import Seq2LabelsModel
from vocabulary import Vocabulary
from utils_gec import PAD, UNK, START_TOKEN, get_target_sent_by_edits

current_dir = sys.path[0].replace('\\', '/')

logging.getLogger("werkzeug").setLevel(logging.ERROR)
logger = logging.getLogger(__file__)


class GecBERTModel(torch.nn.Module):
    def __init__(
        self,
        vocab_path=None,
        model_paths=None,
        weights=None,
        device=None,
        max_len=64,
        min_len=3,
        lowercase_tokens=False,
        log=False,
        iterations=3,
        min_error_probability=0.0,
        confidence=0,
        resolve_cycles=False,
        split_chunk=False,
        chunk_size=48,
        overlap_size=12,
        min_words_cut=6,
        punc_dict={':', ".", ",", "?"},
    ):
        r"""
        Args:
            vocab_path (`str`):
                Path to the vocabulary directory.
            model_paths (`List[str]`):
                List of model paths.
            weights (`List[float]`, *optional*, defaults to `None`):
                Weight of each model. Only relevant when ensembling several models.
            device (`str` or `torch.device`, *optional*, defaults to `None`):
                Device to load the model on. If not set, the device is chosen automatically.
            max_len (`int`, defaults to 64):
                Maximum sentence length to be processed (longer sentences are truncated).
            min_len (`int`, defaults to 3):
                Minimum sentence length to be processed (shorter sentences are returned unchanged).
            lowercase_tokens (`bool`, defaults to `False`):
                Whether to lowercase tokens.
            log (`bool`, defaults to `False`):
                Whether to enable logging.
            iterations (`int`, defaults to 3):
                Maximum number of correction iterations to run during inference.
            min_error_probability (`float`, defaults to `0.0`):
                Minimum probability an action must have in order to be applied.
            confidence (`float`, defaults to `0.0`):
                Probability mass added to the $KEEP token.
            split_chunk (`bool`, defaults to `False`):
                Whether to split long sentences into multiple segments of `chunk_size`.
                Warning: if `chunk_size > max_len`, each segment is truncated to `max_len`.
            chunk_size (`int`, defaults to 48):
                Length of each segment (in words). Only relevant if `split_chunk is True`.
            overlap_size (`int`, defaults to 12):
                Overlap (in words) between two consecutive segments. Only relevant if `split_chunk is True`.
            min_words_cut (`int`, defaults to 6):
                Minimum number of words to cut while merging two consecutive segments.
                Only relevant if `split_chunk is True`.
            punc_dict (`Set[str]`, defaults to `{':', ".", ",", "?"}`):
                Set of punctuation marks.
        """
        super().__init__()
        if isinstance(model_paths, str):
            model_paths = [model_paths]
        self.model_weights = list(map(float, weights)) if weights else [1] * len(model_paths)
        self.device = (
            torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device)
        )
        self.max_len = max_len
        self.min_len = min_len
        self.lowercase_tokens = lowercase_tokens
        self.min_error_probability = min_error_probability
        self.vocab = Vocabulary.from_files(vocab_path)
        self.incorr_index = self.vocab.get_token_index("INCORRECT", "d_tags")
        self.log = log
        self.iterations = iterations
        self.confidence = confidence
        self.resolve_cycles = resolve_cycles

        assert (
            chunk_size > 0 and chunk_size // 2 >= overlap_size
        ), "Chunk merging requires the overlap size to be at most half of the chunk size"
        self.split_chunk = split_chunk
        self.chunk_size = chunk_size
        self.overlap_size = overlap_size
        self.min_words_cut = min_words_cut
        self.stride = chunk_size - overlap_size
        self.punc_dict = punc_dict
        self.punc_str = '[' + ''.join(re.escape(x) for x in punc_dict) + ']'

        self.indexers = []
        self.models = []
        for model_path in model_paths:
            model = Seq2LabelsModel.from_pretrained(model_path)
            config = model.config
            model_name = current_dir + "/" + config.pretrained_name_or_path
            special_tokens_fix = config.special_tokens_fix
            self.indexers.append(self._get_indexer(model_name, special_tokens_fix))
            model.eval().to(self.device)
            self.models.append(model)

    def _get_indexer(self, weights_name, special_tokens_fix):
        tokenizer = AutoTokenizer.from_pretrained(
            weights_name, do_basic_tokenize=False,
            do_lower_case=self.lowercase_tokens, model_max_length=1024
        )
        # adjust attributes so all tokenizer types expose a `vocab` dict
        if hasattr(tokenizer, 'encoder'):
            tokenizer.vocab = tokenizer.encoder
        if hasattr(tokenizer, 'sp_model'):
            tokenizer.vocab = defaultdict(lambda: 1)
            for i in range(tokenizer.sp_model.get_piece_size()):
                tokenizer.vocab[tokenizer.sp_model.id_to_piece(i)] = i

        if special_tokens_fix:
            tokenizer.add_tokens([START_TOKEN])
            tokenizer.vocab[START_TOKEN] = len(tokenizer) - 1
        return tokenizer

    def forward(self, text: Union[str, List[str], List[List[str]]], is_split_into_words=False):
        # Input type checking for a clearer error message
        def _is_valid_text_input(t):
            if isinstance(t, str):
                # Strings are fine
                return True
            elif isinstance(t, (list, tuple)):
                # Lists are fine as long as they are...
                if len(t) == 0:
                    # ... empty
                    return True
                elif isinstance(t[0], str):
                    # ... a list of strings
                    return True
                elif isinstance(t[0], (list, tuple)):
                    # ... a list with an empty list or with a list of strings
                    return len(t[0]) == 0 or isinstance(t[0][0], str)
                else:
                    return False
            else:
                return False

        if not _is_valid_text_input(text):
            raise ValueError(
                "text input must be of type `str` (single example), `List[str]` (batch or single pretokenized example) "
                "or `List[List[str]]` (batch of pretokenized examples)."
            )

        if is_split_into_words:
            is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
        else:
            is_batched = isinstance(text, (list, tuple))
            if is_batched:
                text = [x.split() for x in text]
            else:
                text = text.split()

        if not is_batched:
            text = [text]

        return self.handle_batch(text)

    def split_chunks(self, batch):
        # returns the chunked batch plus (start, end) index pairs per sentence
        result = []
        indices = []
        for tokens in batch:
            start = len(result)
            num_token = len(tokens)
            if num_token <= self.chunk_size:
                result.append(tokens)
            elif num_token > self.chunk_size and num_token < (self.chunk_size * 2 - self.overlap_size):
                split_idx = (num_token + self.overlap_size + 1) // 2
                result.append(tokens[:split_idx])
                result.append(tokens[split_idx - self.overlap_size:])
            else:
                for i in range(0, num_token - self.overlap_size, self.stride):
                    result.append(tokens[i: i + self.chunk_size])

            indices.append((start, len(result)))

        return result, indices

    def check_alnum(self, s):
        if len(s) < 2:
            return False
        return not (s.isalpha() or s.isdigit())

    def apply_chunk_merging(self, tokens, next_tokens):
        # Return the next tokens if the current token list is empty
        if not tokens:
            return next_tokens

        source_token_idx = []
        target_token_idx = []
        source_tokens = []
        target_tokens = []
        num_keep = self.overlap_size - self.min_words_cut
        i = 0
        while len(source_token_idx) < self.overlap_size and -i < len(tokens):
            i -= 1
            if tokens[i] not in self.punc_dict:
                source_token_idx.insert(0, i)
                source_tokens.insert(0, tokens[i].lower())

        i = 0
        while len(target_token_idx) < self.overlap_size and i < len(next_tokens):
            if next_tokens[i] not in self.punc_dict:
                target_token_idx.append(i)
                target_tokens.append(next_tokens[i].lower())
            i += 1

        matcher = SequenceMatcher(None, source_tokens, target_tokens)
        diffs = list(matcher.get_opcodes())

        for diff in diffs:
            tag, i1, i2, j1, j2 = diff
            if tag == "equal":
                if i1 >= num_keep:
                    tail_idx = source_token_idx[i1]
                    head_idx = target_token_idx[j1]
                    break
                elif i2 > num_keep:
                    tail_idx = source_token_idx[num_keep]
                    head_idx = target_token_idx[j2 - i2 + num_keep]
                    break
            elif tag == "delete" and i1 == 0:
                num_keep += i2 // 2

        tokens = tokens[:tail_idx] + next_tokens[head_idx:]
        return tokens

    def merge_chunks(self, batch):
        result = []
        if len(batch) == 1 or self.overlap_size == 0:
            for sub_tokens in batch:
                result.extend(sub_tokens)
        else:
            for sub_tokens in batch:
                try:
                    result = self.apply_chunk_merging(result, sub_tokens)
                except Exception as e:
                    print(e)

        result = " ".join(result)
        return result

    def predict(self, batches):
        t11 = time()
        predictions = []
        for batch, model in zip(batches, self.models):
            batch = batch.to(self.device)
            with torch.no_grad():
                prediction = model.forward(**batch)
            predictions.append(prediction)

        preds, idx, error_probs = self._convert(predictions)
        t55 = time()
        if self.log:
            print(f"Inference time {t55 - t11}")
        return preds, idx, error_probs

    def get_token_action(self, token, index, prob, sugg_token):
        """Get the suggested action for a token."""
        # cases when we don't need to do anything
        if prob < self.min_error_probability or sugg_token in [UNK, PAD, '$KEEP']:
            return None

        if sugg_token.startswith('$REPLACE_') or sugg_token.startswith('$TRANSFORM_') or sugg_token == '$DELETE':
            start_pos = index
            end_pos = index + 1
        elif sugg_token.startswith("$APPEND_") or sugg_token.startswith("$MERGE_"):
            start_pos = index + 1
            end_pos = index + 1

        if sugg_token == "$DELETE":
            sugg_token_clear = ""
        elif sugg_token.startswith('$TRANSFORM_') or sugg_token.startswith("$MERGE_"):
            sugg_token_clear = sugg_token[:]
        else:
            sugg_token_clear = sugg_token[sugg_token.index('_') + 1:]

        return start_pos - 1, end_pos - 1, sugg_token_clear, prob

    def preprocess(self, token_batch):
        seq_lens = [len(sequence) for sequence in token_batch if sequence]
        if not seq_lens:
            return []
        max_len = min(max(seq_lens), self.max_len)
        batches = []
        for indexer in self.indexers:
            # prepend the START token once per indexer without mutating token_batch
            prepared = [[START_TOKEN] + sequence[:max_len] for sequence in token_batch]
            batch = indexer(
                prepared,
                return_tensors="pt",
                padding=True,
                is_split_into_words=True,
                truncation=True,
                add_special_tokens=False,
            )
            offset_batch = []
            for i in range(len(prepared)):
                word_ids = batch.word_ids(batch_index=i)
                offsets = [0]
                for j in range(1, len(word_ids)):
                    if word_ids[j] != word_ids[j - 1]:
                        offsets.append(j)
                offset_batch.append(torch.LongTensor(offsets))

            batch["input_offsets"] = torch.nn.utils.rnn.pad_sequence(
                offset_batch, batch_first=True, padding_value=0
            ).to(torch.long)

            batches.append(batch)

        return batches

    def _convert(self, data):
        all_class_probs = torch.zeros_like(data[0]['logits'])
        error_probs = torch.zeros_like(data[0]['max_error_probability'])
        for output, weight in zip(data, self.model_weights):
            class_probabilities_labels = torch.softmax(output['logits'], dim=-1)
            all_class_probs += weight * class_probabilities_labels / sum(self.model_weights)
            class_probabilities_d = torch.softmax(output['detect_logits'], dim=-1)
            error_probs_d = class_probabilities_d[:, :, self.incorr_index]
            incorr_prob = torch.max(error_probs_d, dim=-1)[0]
            error_probs += weight * incorr_prob / sum(self.model_weights)

        max_vals = torch.max(all_class_probs, dim=-1)
        probs = max_vals[0].tolist()
        idx = max_vals[1].tolist()
        return probs, idx, error_probs.tolist()

    def update_final_batch(self, final_batch, pred_ids, pred_batch, prev_preds_dict):
        new_pred_ids = []
        total_updated = 0
        for i, orig_id in enumerate(pred_ids):
            orig = final_batch[orig_id]
            pred = pred_batch[i]
            prev_preds = prev_preds_dict[orig_id]
            if orig != pred and pred not in prev_preds:
                final_batch[orig_id] = pred
                new_pred_ids.append(orig_id)
                prev_preds_dict[orig_id].append(pred)
                total_updated += 1
            elif orig != pred and pred in prev_preds:
                # update the final batch, but stop iterating on this sentence
                final_batch[orig_id] = pred
                total_updated += 1
            else:
                continue
        return final_batch, new_pred_ids, total_updated

    def postprocess_batch(self, batch, all_probabilities, all_idxs, error_probs):
        all_results = []
        noop_index = self.vocab.get_token_index("$KEEP", "labels")
        for tokens, probabilities, idxs, error_prob in zip(batch, all_probabilities, all_idxs, error_probs):
            length = min(len(tokens), self.max_len)
            edits = []

            # skip the whole sentence if there are no errors
            if max(idxs) == 0:
                all_results.append(tokens)
                continue

            # skip the whole sentence if the error probability is too low
            if error_prob < self.min_error_probability:
                all_results.append(tokens)
                continue

            for i in range(length + 1):
                # because of the START token
                if i == 0:
                    token = START_TOKEN
                else:
                    token = tokens[i - 1]
                # skip if there is no error
                if idxs[i] == noop_index:
                    continue

                sugg_token = self.vocab.get_token_from_index(idxs[i], namespace='labels')
                action = self.get_token_action(token, i, probabilities[i], sugg_token)
                if not action:
                    continue

                edits.append(action)
            all_results.append(get_target_sent_by_edits(tokens, edits))
        return all_results

    def handle_batch(self, full_batch, merge_punc=True):
        """
        Handle a batch of requests.
        """
        if self.split_chunk:
            full_batch, indices = self.split_chunks(full_batch)
        else:
            indices = None
        final_batch = full_batch[:]
        batch_size = len(full_batch)
        prev_preds_dict = {i: [final_batch[i]] for i in range(len(final_batch))}
        short_ids = [i for i in range(len(full_batch)) if len(full_batch[i]) < self.min_len]
        pred_ids = [i for i in range(len(full_batch)) if i not in short_ids]
        total_updates = 0

        for n_iter in range(self.iterations):
            orig_batch = [final_batch[i] for i in pred_ids]

            sequences = self.preprocess(orig_batch)

            if not sequences:
                break
            probabilities, idxs, error_probs = self.predict(sequences)

            pred_batch = self.postprocess_batch(orig_batch, probabilities, idxs, error_probs)
            if self.log:
                print(f"Iteration {n_iter + 1}. Predicted {round(100 * len(pred_ids) / batch_size, 1)}% of sentences.")

            final_batch, pred_ids, cnt = self.update_final_batch(final_batch, pred_ids, pred_batch, prev_preds_dict)
            total_updates += cnt

            if not pred_ids:
                break
        if self.split_chunk:
            final_batch = [self.merge_chunks(final_batch[start:end]) for (start, end) in indices]
        else:
            final_batch = [" ".join(x) for x in final_batch]
        if merge_punc:
            final_batch = [re.sub(r'\s+(%s)' % self.punc_str, r'\1', x) for x in final_batch]

        return final_batch
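
The chunking logic in `split_chunks` is easiest to see on a toy example. The sketch below restates the windowing of the long-sentence branch with deliberately small sizes (the shipped defaults are chunk_size=48, overlap_size=12):

# Standalone restatement of the windowing in GecBERTModel.split_chunks,
# using toy sizes: chunk_size=4, overlap_size=2, hence stride=2.
chunk_size, overlap_size = 4, 2
stride = chunk_size - overlap_size
tokens = [f"w{i}" for i in range(10)]

chunks = [tokens[i:i + chunk_size]
          for i in range(0, len(tokens) - overlap_size, stride)]
print(chunks)
# [['w0', 'w1', 'w2', 'w3'], ['w2', 'w3', 'w4', 'w5'],
#  ['w4', 'w5', 'w6', 'w7'], ['w6', 'w7', 'w8', 'w9']]

`merge_chunks` then stitches the corrected windows back together, aligning the overlapping words with `SequenceMatcher`.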
modeling_seq2labels.py
ADDED
@@ -0,0 +1,124 @@
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import sys

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModel, BertPreTrainedModel
from transformers.modeling_outputs import ModelOutput

current_dir = sys.path[0].replace('\\', '/')


def get_range_vector(size: int, device: torch.device) -> torch.Tensor:
    """
    Returns a range vector of the desired size, starting at 0. Creating the tensor
    directly on the target device avoids copying data from CPU to GPU.
    """
    return torch.arange(0, size, dtype=torch.long, device=device)


@dataclass
class Seq2LabelsOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    detect_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    max_error_probability: Optional[torch.FloatTensor] = None


class Seq2LabelsModel(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.num_detect_classes = config.num_detect_classes
        self.label_smoothing = config.label_smoothing

        if config.load_pretrained:
            self.bert = AutoModel.from_pretrained(current_dir + "/" + config.pretrained_name_or_path)
            bert_config = self.bert.config
        else:
            bert_config = AutoConfig.from_pretrained(current_dir + "/" + config.pretrained_name_or_path)
            self.bert = AutoModel.from_config(bert_config)

        if config.special_tokens_fix:
            try:
                vocab_size = self.bert.embeddings.word_embeddings.num_embeddings
            except AttributeError:
                # reserve more space
                vocab_size = self.bert.word_embedding.num_embeddings + 5
            self.bert.resize_token_embeddings(vocab_size + 1)

        predictor_dropout = config.predictor_dropout if config.predictor_dropout is not None else 0.0
        self.dropout = nn.Dropout(predictor_dropout)
        self.classifier = nn.Linear(bert_config.hidden_size, config.vocab_size)
        self.detector = nn.Linear(bert_config.hidden_size, config.num_detect_classes)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        input_offsets: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        d_tags: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2LabelsOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        if input_offsets is not None:
            # input_offsets is (batch_size, num_words): the index of the first
            # subword of each original word
            range_vector = get_range_vector(input_offsets.size(0), device=sequence_output.device).unsqueeze(1)
            # gather one embedding per original word: (batch_size, num_words, hidden_size)
            sequence_output = sequence_output[range_vector, input_offsets]

        logits = self.classifier(self.dropout(sequence_output))
        logits_d = self.detector(sequence_output)

        loss = None
        if labels is not None and d_tags is not None:
            loss_labels_fct = CrossEntropyLoss(label_smoothing=self.label_smoothing)
            loss_d_fct = CrossEntropyLoss()
            loss_labels = loss_labels_fct(logits.view(-1, self.num_labels), labels.view(-1))
            loss_d = loss_d_fct(logits_d.view(-1, self.num_detect_classes), d_tags.view(-1))
            loss = loss_labels + loss_d

        if not return_dict:
            output = (logits, logits_d) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return Seq2LabelsOutput(
            loss=loss,
            logits=logits,
            detect_logits=logits_d,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            max_error_probability=torch.ones(logits.size(0), device=logits.device),
        )
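
The `input_offsets` gather in `forward` reduces subword embeddings to one vector per original word via advanced indexing. A minimal sketch with made-up shapes:

import torch

batch_size, seq_len, hidden = 2, 6, 3
sequence_output = torch.randn(batch_size, seq_len, hidden)   # encoder output
input_offsets = torch.tensor([[0, 1, 3], [0, 2, 4]])         # first-subword index per word

range_vector = torch.arange(batch_size).unsqueeze(1)         # shape (2, 1)
word_output = sequence_output[range_vector, input_offsets]   # shape (2, 3, hidden)
print(word_output.shape)  # torch.Size([2, 3, 3]): one embedding per word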
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a6e2a5c2b1cbf16a9fd0b88c0dc8585f3832a60d10eea8140854f8d8f32c188d
size 1112304873
utils_gec.py
ADDED
@@ -0,0 +1,233 @@
import os
from pathlib import Path
import re


VOCAB_DIR = Path(__file__).resolve().parent
PAD = "@@PADDING@@"
UNK = "@@UNKNOWN@@"
START_TOKEN = "$START"
SEQ_DELIMETERS = {"tokens": " ", "labels": "SEPL|||SEPR", "operations": "SEPL__SEPR"}


def get_verb_form_dicts():
    path_to_dict = os.path.join(VOCAB_DIR, "verb-form-vocab.txt")
    encode, decode = {}, {}
    with open(path_to_dict, encoding="utf-8") as f:
        for line in f:
            words, tags = line.split(":")
            word1, word2 = words.split("_")
            tag1, tag2 = tags.split("_")
            decode_key = f"{word1}_{tag1}_{tag2.strip()}"
            if decode_key not in decode:
                encode[words] = tags
                decode[decode_key] = word2
    return encode, decode


ENCODE_VERB_DICT, DECODE_VERB_DICT = get_verb_form_dicts()


def get_target_sent_by_edits(source_tokens, edits):
    target_tokens = source_tokens[:]
    shift_idx = 0
    for edit in edits:
        start, end, label, _ = edit
        target_pos = start + shift_idx
        if start < 0:
            continue
        elif len(target_tokens) > target_pos:
            source_token = target_tokens[target_pos]
        else:
            source_token = ""
        if label == "":
            del target_tokens[target_pos]
            shift_idx -= 1
        elif start == end:
            word = label.replace("$APPEND_", "")
            # Avoid appending the same token twice
            if (target_pos < len(target_tokens) and target_tokens[target_pos] == word) or (
                target_pos > 0 and target_tokens[target_pos - 1] == word
            ):
                continue
            target_tokens[target_pos:target_pos] = [word]
            shift_idx += 1
        elif label.startswith("$TRANSFORM_"):
            word = apply_reverse_transformation(source_token, label)
            if word is None:
                word = source_token
            target_tokens[target_pos] = word
        elif start == end - 1:
            word = label.replace("$REPLACE_", "")
            target_tokens[target_pos] = word
        elif label.startswith("$MERGE_"):
            target_tokens[target_pos + 1: target_pos + 1] = [label]
            shift_idx += 1

    return replace_merge_transforms(target_tokens)


def replace_merge_transforms(tokens):
    if all(not x.startswith("$MERGE_") for x in tokens):
        return tokens
    if tokens[0].startswith("$MERGE_"):
        tokens = tokens[1:]
    if tokens[-1].startswith("$MERGE_"):
        tokens = tokens[:-1]

    target_line = " ".join(tokens)
    target_line = target_line.replace(" $MERGE_HYPHEN ", "-")
    target_line = target_line.replace(" $MERGE_SPACE ", "")
    target_line = re.sub(r'([\.\,\?\:]\s+)+', r'\1', target_line)
    return target_line.split()


def convert_using_case(token, smart_action):
    if not smart_action.startswith("$TRANSFORM_CASE_"):
        return token
    if smart_action.endswith("LOWER"):
        return token.lower()
    elif smart_action.endswith("UPPER"):
        return token.upper()
    elif smart_action.endswith("CAPITAL"):
        return token.capitalize()
    elif smart_action.endswith("CAPITAL_1"):
        return token[0] + token[1:].capitalize()
    elif smart_action.endswith("UPPER_-1"):
        return token[:-1].upper() + token[-1]
    else:
        return token


def convert_using_verb(token, smart_action):
    key_word = "$TRANSFORM_VERB_"
    if not smart_action.startswith(key_word):
        raise Exception(f"Unknown action type {smart_action}")
    encoding_part = f"{token}_{smart_action[len(key_word):]}"
    decoded_target_word = decode_verb_form(encoding_part)
    return decoded_target_word


def convert_using_split(token, smart_action):
    key_word = "$TRANSFORM_SPLIT"
    if not smart_action.startswith(key_word):
        raise Exception(f"Unknown action type {smart_action}")
    target_words = token.split("-")
    return " ".join(target_words)


def convert_using_plural(token, smart_action):
    if smart_action.endswith("PLURAL"):
        return token + "s"
    elif smart_action.endswith("SINGULAR"):
        return token[:-1]
    else:
        raise Exception(f"Unknown action type {smart_action}")


def apply_reverse_transformation(source_token, transform):
    if transform.startswith("$TRANSFORM"):
        # deal with equal
        if transform == "$KEEP":
            return source_token
        # deal with case
        if transform.startswith("$TRANSFORM_CASE"):
            return convert_using_case(source_token, transform)
        # deal with verbs
        if transform.startswith("$TRANSFORM_VERB"):
            return convert_using_verb(source_token, transform)
        # deal with splits
        if transform.startswith("$TRANSFORM_SPLIT"):
            return convert_using_split(source_token, transform)
        # deal with singular/plural
        if transform.startswith("$TRANSFORM_AGREEMENT"):
            return convert_using_plural(source_token, transform)
        # raise an exception if the type is not recognized
        raise Exception(f"Unknown action type {transform}")
    else:
        return source_token


def read_parallel_lines(fn1, fn2):
    with open(fn1, encoding='utf-8') as f1, open(fn2, encoding='utf-8') as f2:
        for line1, line2 in zip(f1, f2):
            yield line1.strip(), line2.strip()


def read_lines(fn, skip_strip=False):
    if not os.path.exists(fn):
        return []
    with open(fn, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    return [s.strip() for s in lines if s.strip() or skip_strip]


def write_lines(fn, lines, mode='w'):
    if mode == 'w' and os.path.exists(fn):
        os.remove(fn)
    with open(fn, encoding='utf-8', mode=mode) as f:
        f.writelines(['%s\n' % s for s in lines])


def decode_verb_form(original):
    return DECODE_VERB_DICT.get(original)


def encode_verb_form(original_word, corrected_word):
    decoding_request = original_word + "_" + corrected_word
    decoding_response = ENCODE_VERB_DICT.get(decoding_request, "").strip()
    if original_word and decoding_response:
        answer = decoding_response
    else:
        answer = None
    return answer


def get_weights_name(transformer_name, lowercase):
    if transformer_name == 'bert' and lowercase:
        return 'bert-base-uncased'
    if transformer_name == 'bert' and not lowercase:
        return 'bert-base-cased'
    if transformer_name == 'bert-large' and not lowercase:
        return 'bert-large-cased'
    if transformer_name == 'distilbert':
        if not lowercase:
            print('Warning! This model was trained only on uncased sentences.')
        return 'distilbert-base-uncased'
    if transformer_name == 'albert':
        if not lowercase:
            print('Warning! This model was trained only on uncased sentences.')
        return 'albert-base-v1'
    if lowercase:
        print('Warning! This model was trained only on cased sentences.')
    if transformer_name == 'roberta':
        return 'roberta-base'
    if transformer_name == 'roberta-large':
        return 'roberta-large'
    if transformer_name == 'gpt2':
        return 'gpt2'
    if transformer_name == 'transformerxl':
        return 'transfo-xl-wt103'
    if transformer_name == 'xlnet':
        return 'xlnet-base-cased'
    if transformer_name == 'xlnet-large':
        return 'xlnet-large-cased'

    return transformer_name
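
To see `get_target_sent_by_edits` in action, a small worked example (it assumes the repository files, including verb-form-vocab.txt, are importable; the edit tuples mirror what `GecBERTModel.get_token_action` emits, `(start, end, label, prob)`, with `$APPEND_`/`$REPLACE_` prefixes already stripped but `$TRANSFORM_`/`$MERGE_` labels kept whole, and the values here are made up):

from utils_gec import get_target_sent_by_edits

tokens = ["toi", "di", "hoc"]
edits = [
    (0, 1, "$TRANSFORM_CASE_CAPITAL", 0.9),  # capitalize the first word
    (2, 2, ",", 0.8),                        # append "," after the second word
]
print(get_target_sent_by_edits(tokens, edits))
# ['Toi', 'di', ',', 'hoc']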
verb-form-vocab.txt
ADDED
The diff for this file is too large to render.
vocabulary.py
ADDED
@@ -0,0 +1,277 @@
import codecs
from collections import defaultdict
import logging
import os
import re
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union

from filelock import FileLock


logger = logging.getLogger(__name__)

DEFAULT_NON_PADDED_NAMESPACES = ("*tags", "*labels")
DEFAULT_PADDING_TOKEN = "@@PADDING@@"
DEFAULT_OOV_TOKEN = "@@UNKNOWN@@"
NAMESPACE_PADDING_FILE = "non_padded_namespaces.txt"
_NEW_LINE_REGEX = re.compile(r"\n|\r\n")


def namespace_match(pattern: str, namespace: str):
    """
    Matches a namespace pattern against a namespace string. For example, `*tags` matches
    `passage_tags` and `question_tags` and `tokens` matches `tokens` but not
    `stemmed_tokens`.
    """
    if pattern[0] == "*" and namespace.endswith(pattern[1:]):
        return True
    elif pattern == namespace:
        return True
    return False


class _NamespaceDependentDefaultDict(defaultdict):
    """
    This is a [defaultdict]
    (https://docs.python.org/2/library/collections.html#collections.defaultdict) where the
    default value is dependent on the key that is passed.

    We use "namespaces" in the :class:`Vocabulary` object to keep track of several different
    mappings from strings to integers, so that we have a consistent API for mapping words, tags,
    labels, characters, or whatever else you want, into integers. The issue is that some of those
    namespaces (words and characters) should have integers reserved for padding and
    out-of-vocabulary tokens, while others (labels and tags) shouldn't. This class allows you to
    specify filters on the namespace (the key used in the `defaultdict`), and use different
    default values depending on whether the namespace passes the filter.

    To do filtering, we take a set of `non_padded_namespaces`. This is a set of strings
    that are either matched exactly against the keys, or treated as suffixes, if the
    string starts with `*`. In other words, if `*tags` is in `non_padded_namespaces` then
    `passage_tags`, `question_tags`, etc. (anything that ends with `tags`) will have the
    `non_padded` default value.

    # Parameters

    non_padded_namespaces : `Iterable[str]`
        A set / list / tuple of strings describing which namespaces are not padded. If a namespace
        (key) is missing from this dictionary, we will use :func:`namespace_match` to see whether
        the namespace should be padded. If the given namespace matches any of the strings in this
        list, we will use `non_padded_function` to initialize the value for that namespace, and
        we will use `padded_function` otherwise.
    padded_function : `Callable[[], Any]`
        A zero-argument function to call to initialize a value for a namespace that `should` be
        padded.
    non_padded_function : `Callable[[], Any]`
        A zero-argument function to call to initialize a value for a namespace that should `not` be
        padded.
    """

    def __init__(
        self,
        non_padded_namespaces: Iterable[str],
        padded_function: Callable[[], Any],
        non_padded_function: Callable[[], Any],
    ) -> None:
        self._non_padded_namespaces = set(non_padded_namespaces)
        self._padded_function = padded_function
        self._non_padded_function = non_padded_function
        super().__init__()

    def add_non_padded_namespaces(self, non_padded_namespaces: Set[str]):
        # add non_padded_namespaces which weren't already present
        self._non_padded_namespaces.update(non_padded_namespaces)


class _TokenToIndexDefaultDict(_NamespaceDependentDefaultDict):
    def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
        super().__init__(non_padded_namespaces, lambda: {padding_token: 0, oov_token: 1}, lambda: {})


class _IndexToTokenDefaultDict(_NamespaceDependentDefaultDict):
    def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
        super().__init__(non_padded_namespaces, lambda: {0: padding_token, 1: oov_token}, lambda: {})


class Vocabulary:
    def __init__(
        self,
        counter: Dict[str, Dict[str, int]] = None,
        min_count: Dict[str, int] = None,
        max_vocab_size: Union[int, Dict[str, int]] = None,
        non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES,
        pretrained_files: Optional[Dict[str, str]] = None,
        only_include_pretrained_words: bool = False,
        tokens_to_add: Dict[str, List[str]] = None,
        min_pretrained_embeddings: Dict[str, int] = None,
        padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
        oov_token: Optional[str] = DEFAULT_OOV_TOKEN,
    ) -> None:
        self._padding_token = padding_token if padding_token is not None else DEFAULT_PADDING_TOKEN
        self._oov_token = oov_token if oov_token is not None else DEFAULT_OOV_TOKEN

        self._non_padded_namespaces = set(non_padded_namespaces)

        self._token_to_index = _TokenToIndexDefaultDict(
            self._non_padded_namespaces, self._padding_token, self._oov_token
        )
        self._index_to_token = _IndexToTokenDefaultDict(
            self._non_padded_namespaces, self._padding_token, self._oov_token
        )

    @classmethod
    def from_files(
        cls,
        directory: Union[str, os.PathLike],
        padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
        oov_token: Optional[str] = DEFAULT_OOV_TOKEN,
    ) -> "Vocabulary":
        """
        Loads a `Vocabulary` that was serialized either using `save_to_files` or inside
        a model archive file.

        # Parameters

        directory : `str`
            The directory or archive file containing the serialized vocabulary.
        """
        logger.info("Loading token dictionary from %s.", directory)
        padding_token = padding_token if padding_token is not None else DEFAULT_PADDING_TOKEN
        oov_token = oov_token if oov_token is not None else DEFAULT_OOV_TOKEN

        if not os.path.isdir(directory):
            raise ValueError(f"{directory} does not exist")

        # We use a lock file to avoid race conditions where multiple processes
        # might be reading/writing from/to the same vocab files at once.
        with FileLock(os.path.join(directory, ".lock")):
            with codecs.open(os.path.join(directory, NAMESPACE_PADDING_FILE), "r", "utf-8") as namespace_file:
                non_padded_namespaces = [namespace_str.strip() for namespace_str in namespace_file]

            vocab = cls(
                non_padded_namespaces=non_padded_namespaces,
                padding_token=padding_token,
                oov_token=oov_token,
            )

            # Check every file in the directory.
            for namespace_filename in os.listdir(directory):
                if namespace_filename == NAMESPACE_PADDING_FILE:
                    continue
                if namespace_filename.startswith("."):
                    continue
                namespace = namespace_filename.replace(".txt", "")
                if any(namespace_match(pattern, namespace) for pattern in non_padded_namespaces):
                    is_padded = False
                else:
                    is_padded = True
                filename = os.path.join(directory, namespace_filename)
                vocab.set_from_file(filename, is_padded, namespace=namespace, oov_token=oov_token)

        return vocab

    @classmethod
    def empty(cls) -> "Vocabulary":
        """
        This method returns a bare vocabulary instantiated with `cls()` (so, `Vocabulary()` if you
        haven't made a subclass of this object). The only reason to call `Vocabulary.empty()`
        instead of `Vocabulary()` is if you are instantiating this object from a config file. We
        register this constructor with the key "empty", so if you know that you don't need to
        compute a vocabulary (either because you're loading a pre-trained model from an archive
        file, you're using a pre-trained transformer that has its own vocabulary, or something
        else), you can use this to avoid having the default vocabulary construction code iterate
        through the data.
        """
        return cls()

    def set_from_file(
        self,
        filename: str,
        is_padded: bool = True,
        oov_token: str = DEFAULT_OOV_TOKEN,
        namespace: str = "tokens",
    ):
        """
        If you already have a vocabulary file for a trained model somewhere, and you really want to
        use that vocabulary file instead of just setting the vocabulary from a dataset, for
        whatever reason, you can do that with this method. You must specify the namespace to use,
        and we assume that you want to use padding and OOV tokens for this.

        # Parameters

        filename : `str`
            The file containing the vocabulary to load. It should be formatted as one token per
            line, with nothing else in the line. The index we assign to the token is the line
            number in the file (1-indexed if `is_padded`, 0-indexed otherwise). Note that this
            file should contain the OOV token string!
        is_padded : `bool`, optional (default=`True`)
            Is this vocabulary padded? For token / word / character vocabularies, this should be
            `True`; while for tag or label vocabularies, this should typically be `False`. If
            `True`, we add a padding token with index 0, and we enforce that the `oov_token` is
            present in the file.
        oov_token : `str`, optional (default=`DEFAULT_OOV_TOKEN`)
            What token does this vocabulary use to represent out-of-vocabulary characters? This
            must show up as a line in the vocabulary file. When we find it, we replace
            `oov_token` with `self._oov_token`, because we only use one OOV token across
            namespaces.
        namespace : `str`, optional (default=`"tokens"`)
            What namespace should we overwrite with this vocab file?
        """
        if is_padded:
            self._token_to_index[namespace] = {self._padding_token: 0}
            self._index_to_token[namespace] = {0: self._padding_token}
        else:
            self._token_to_index[namespace] = {}
            self._index_to_token[namespace] = {}
        with codecs.open(filename, "r", "utf-8") as input_file:
            lines = _NEW_LINE_REGEX.split(input_file.read())
        # Be flexible about having a final newline or not
        if lines and lines[-1] == "":
            lines = lines[:-1]
        for i, line in enumerate(lines):
            index = i + 1 if is_padded else i
            token = line.replace("@@NEWLINE@@", "\n")
            if token == oov_token:
                token = self._oov_token
            self._token_to_index[namespace][token] = index
            self._index_to_token[namespace][index] = token
        if is_padded:
            assert self._oov_token in self._token_to_index[namespace], "OOV token not found!"

    def add_token_to_namespace(self, token: str, namespace: str = "tokens") -> int:
        """
        Adds `token` to the index, if it is not already present. Either way, we return the index of
        the token.
        """
        if not isinstance(token, str):
            raise ValueError(
                "Vocabulary tokens must be strings, or saving and loading will break."
                " Got %s (with type %s)" % (repr(token), type(token))
            )
        if token not in self._token_to_index[namespace]:
            index = len(self._token_to_index[namespace])
            self._token_to_index[namespace][token] = index
            self._index_to_token[namespace][index] = token
            return index
        else:
            return self._token_to_index[namespace][token]

    def add_tokens_to_namespace(self, tokens: List[str], namespace: str = "tokens") -> List[int]:
        """
        Adds `tokens` to the index, if they are not already present. Either way, we return the
        indices of the tokens in the order that they were given.
        """
        return [self.add_token_to_namespace(token, namespace) for token in tokens]

    def get_token_index(self, token: str, namespace: str = "tokens") -> int:
        try:
            return self._token_to_index[namespace][token]
        except KeyError:
            try:
                return self._token_to_index[namespace][self._oov_token]
            except KeyError:
                logger.error("Namespace: %s", namespace)
                logger.error("Token: %s", token)
                raise KeyError(
                    f"'{token}' not found in vocab namespace '{namespace}', and namespace "
                    f"does not contain the default OOV token ('{self._oov_token}')"
                )

    def get_token_from_index(self, index: int, namespace: str = "tokens") -> str:
        return self._index_to_token[namespace][index]

    def get_vocab_size(self, namespace: str = "tokens") -> int:
        return len(self._token_to_index[namespace])

    def get_namespaces(self) -> Set[str]:
        return set(self._index_to_token.keys())
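
A minimal sketch of loading the vocabulary shipped in this commit (it assumes the vocabulary/ directory below is present and writable, since `from_files` takes a `.lock` file inside it):

from vocabulary import Vocabulary

vocab = Vocabulary.from_files("vocabulary")
print(vocab.get_token_index("$KEEP", "labels"))      # 0  (non-padded namespace)
print(vocab.get_token_index("INCORRECT", "d_tags"))  # 1
print(vocab.get_vocab_size("labels"))                # 15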
vocabulary/d_tags.txt
ADDED
@@ -0,0 +1,4 @@
CORRECT
INCORRECT
@@UNKNOWN@@
@@PADDING@@
vocabulary/labels.txt
ADDED
@@ -0,0 +1,15 @@
$KEEP
$TRANSFORM_CASE_CAPITAL
$APPEND_,
$APPEND_.
$TRANSFORM_VERB_VB_VBN
$TRANSFORM_CASE_UPPER
$APPEND_:
$APPEND_?
$TRANSFORM_VERB_VB_VBC
$TRANSFORM_CASE_LOWER
$TRANSFORM_CASE_CAPITAL_1
$TRANSFORM_CASE_UPPER_-1
$MERGE_SPACE
@@UNKNOWN@@
@@PADDING@@
vocabulary/non_padded_namespaces.txt
ADDED
@@ -0,0 +1,2 @@
*tags
*labels
xlm-roberta-base/config.json
ADDED
@@ -0,0 +1,25 @@
{
  "architectures": [
    "XLMRobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "xlm-roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_past": true,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.17.0.dev0",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 250002
}
xlm-roberta-base/tokenizer.json
ADDED
The diff for this file is too large to render.