Commit a02c788
Jean Garcia-Gathright committed
1 Parent(s): 4150cb0

added ernie files
Files changed:
- app.py +2 -2
- app.py~ +7 -0
- ernie/__init__.py +47 -0
- ernie/aggregation_strategies.py +70 -0
- ernie/ernie.py +397 -0
- ernie/helper.py +121 -0
- ernie/models.py +51 -0
- ernie/split_strategies.py +125 -0
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
-import
-import
+from ernie.ernie import SentenceClassifier
+from ernie import helper
 
 def greet(name):
     return "Hello " + name + "!!"
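The new imports are not used anywhere in app.py yet; the interface still only calls greet(). A minimal sketch of how SentenceClassifier could be wired into the Gradio app in a follow-up (hypothetical, not part of this commit; the base model choice and the numeric label names are assumptions):

    # Hypothetical follow-up, not part of this commit: replace the greet()
    # placeholder with the classifier imported above.
    import gradio as gr
    from ernie.ernie import SentenceClassifier, Models

    # Assumption: a fine-tuned model would normally be loaded via model_path;
    # the pretrained base model is used here only to illustrate the call flow.
    classifier = SentenceClassifier(model_name=Models.BertBaseUncased, max_length=64)

    def classify(text):
        # predict_one returns a tuple of class probabilities
        probabilities = classifier.predict_one(text)
        return {str(i): float(p) for i, p in enumerate(probabilities)}

    iface = gr.Interface(fn=classify, inputs="text", outputs="label")
    iface.launch()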
app.py~ ADDED
@@ -0,0 +1,7 @@
import gradio as gr

def greet(name):
    return "Hello " + name + "!!"

iface = gr.Interface(fn=greet, inputs="text", outputs="text")
iface.launch()
ernie/__init__.py ADDED
@@ -0,0 +1,47 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from .ernie import *  # noqa: F401, F403
from tensorflow.python.client import device_lib
import logging

__version__ = '1.0.1'

logging.getLogger().setLevel(logging.WARNING)
logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
logging.basicConfig(
    format='%(asctime)-15s [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)


def _get_cpu_name():
    import cpuinfo
    cpu_info = cpuinfo.get_cpu_info()
    cpu_name = f"{cpu_info['brand_raw']}, {cpu_info['count']} vCores"
    return cpu_name


def _get_gpu_name():
    gpu_name = \
        device_lib\
        .list_local_devices()[3]\
        .physical_device_desc\
        .split(',')[1]\
        .split('name:')[1]\
        .strip()
    return gpu_name


device_name = _get_cpu_name()
device_type = 'CPU'

try:
    device_name = _get_gpu_name()
    device_type = 'GPU'
except IndexError:
    # Detect TPU
    pass

logging.info(f'ernie v{__version__}')
logging.info(f'target device: [{device_type}] {device_name}\n')
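Because the device probe above runs at module level, importing the package is enough to populate the version and device names. A tiny sketch (the printed values depend on the machine):

    import ernie

    # The module-level detection has already run at import time.
    print(ernie.__version__)   # '1.0.1'
    print(ernie.device_type)   # 'CPU' or 'GPU', depending on what device_lib reported
    print(ernie.device_name)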
ernie/aggregation_strategies.py ADDED
@@ -0,0 +1,70 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from statistics import mean


class AggregationStrategy:
    def __init__(
        self,
        method,
        max_items=None,
        top_items=True,
        sorting_class_index=1
    ):
        self.method = method
        self.max_items = max_items
        self.top_items = top_items
        self.sorting_class_index = sorting_class_index

    def aggregate(self, softmax_tuples):
        softmax_dicts = []
        for softmax_tuple in softmax_tuples:
            softmax_dict = {}
            for i, probability in enumerate(softmax_tuple):
                softmax_dict[i] = probability
            softmax_dicts.append(softmax_dict)

        if self.max_items is not None:
            softmax_dicts = sorted(
                softmax_dicts,
                key=lambda x: x[self.sorting_class_index],
                reverse=self.top_items
            )
            if self.max_items < len(softmax_dicts):
                softmax_dicts = softmax_dicts[:self.max_items]

        softmax_list = []
        for key in softmax_dicts[0].keys():
            softmax_list.append(self.method(
                [probabilities[key] for probabilities in softmax_dicts]))
        softmax_tuple = tuple(softmax_list)
        return softmax_tuple


class AggregationStrategies:
    Mean = AggregationStrategy(method=mean)
    MeanTopFiveBinaryClassification = AggregationStrategy(
        method=mean,
        max_items=5,
        top_items=True,
        sorting_class_index=1
    )
    MeanTopTenBinaryClassification = AggregationStrategy(
        method=mean,
        max_items=10,
        top_items=True,
        sorting_class_index=1
    )
    MeanTopFifteenBinaryClassification = AggregationStrategy(
        method=mean,
        max_items=15,
        top_items=True,
        sorting_class_index=1
    )
    MeanTopTwentyBinaryClassification = AggregationStrategy(
        method=mean,
        max_items=20,
        top_items=True,
        sorting_class_index=1
    )
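The aggregation step reduces one softmax tuple per text fragment to a single tuple per text, optionally keeping only the fragments with the highest probability for the sorting class. A small self-contained sketch of how these strategies behave on made-up probabilities (the numbers are illustrative only):

    from statistics import mean
    from ernie.aggregation_strategies import AggregationStrategy, AggregationStrategies

    # Softmax outputs for three fragments of the same text (illustrative values).
    fragment_predictions = [(0.9, 0.1), (0.6, 0.4), (0.3, 0.7)]

    # Plain mean over all fragments -> (0.6, 0.4)
    print(AggregationStrategies.Mean.aggregate(fragment_predictions))

    # Keep only the 2 fragments with the highest probability for class 1,
    # then average them: mean of (0.6, 0.4) and (0.3, 0.7) -> (0.45, 0.55)
    top_two = AggregationStrategy(method=mean, max_items=2,
                                  top_items=True, sorting_class_index=1)
    print(top_two.aggregate(fragment_predictions))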
ernie/ernie.py ADDED
@@ -0,0 +1,397 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import pandas as pd
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoConfig,
    TFAutoModelForSequenceClassification,
)
from tensorflow import keras
from sklearn.model_selection import train_test_split
import logging
import time
from .models import Models, ModelsByFamily  # noqa: F401
from .split_strategies import (  # noqa: F401
    SplitStrategy,
    SplitStrategies,
    RegexExpressions
)
from .aggregation_strategies import (  # noqa: F401
    AggregationStrategy,
    AggregationStrategies
)
from .helper import (
    get_features,
    softmax,
    remove_dir,
    make_dir,
    copy_dir
)

AUTOSAVE_PATH = './ernie-autosave/'


def clean_autosave():
    remove_dir(AUTOSAVE_PATH)


class SentenceClassifier:
    def __init__(self,
                 model_name=Models.BertBaseUncased,
                 model_path=None,
                 max_length=64,
                 labels_no=2,
                 tokenizer_kwargs=None,
                 model_kwargs=None):
        self._loaded_data = False
        self._model_path = None

        if model_kwargs is None:
            model_kwargs = {}
        model_kwargs['num_labels'] = labels_no

        if tokenizer_kwargs is None:
            tokenizer_kwargs = {}
        tokenizer_kwargs['max_len'] = max_length

        if model_path is not None:
            self._load_local_model(model_path)
        else:
            self._load_remote_model(model_name, tokenizer_kwargs, model_kwargs)

    @property
    def model(self):
        return self._model

    @property
    def tokenizer(self):
        return self._tokenizer

    def load_dataset(self,
                     dataframe=None,
                     validation_split=0.1,
                     random_state=None,
                     stratify=None,
                     csv_path=None,
                     read_csv_kwargs=None):

        if dataframe is None and csv_path is None:
            raise ValueError

        if csv_path is not None:
            dataframe = pd.read_csv(csv_path, **read_csv_kwargs)

        sentences = list(dataframe[dataframe.columns[0]])
        labels = dataframe[dataframe.columns[1]].values

        (
            training_sentences,
            validation_sentences,
            training_labels,
            validation_labels
        ) = train_test_split(
            sentences,
            labels,
            test_size=validation_split,
            shuffle=True,
            random_state=random_state,
            stratify=stratify
        )

        self._training_features = get_features(
            self._tokenizer, training_sentences, training_labels)

        self._training_size = len(training_sentences)

        self._validation_features = get_features(
            self._tokenizer,
            validation_sentences,
            validation_labels
        )
        self._validation_split = len(validation_sentences)

        logging.info(f'training_size: {self._training_size}')
        logging.info(f'validation_split: {self._validation_split}')

        self._loaded_data = True

    def fine_tune(self,
                  epochs=4,
                  learning_rate=2e-5,
                  epsilon=1e-8,
                  clipnorm=1.0,
                  optimizer_function=keras.optimizers.Adam,
                  optimizer_kwargs=None,
                  loss_function=keras.losses.SparseCategoricalCrossentropy,
                  loss_kwargs=None,
                  accuracy_function=keras.metrics.SparseCategoricalAccuracy,
                  accuracy_kwargs=None,
                  training_batch_size=32,
                  validation_batch_size=64,
                  **kwargs):
        if not self._loaded_data:
            raise Exception('Data has not been loaded.')

        if optimizer_kwargs is None:
            optimizer_kwargs = {
                'learning_rate': learning_rate,
                'epsilon': epsilon,
                'clipnorm': clipnorm
            }
        optimizer = optimizer_function(**optimizer_kwargs)

        if loss_kwargs is None:
            loss_kwargs = {'from_logits': True}
        loss = loss_function(**loss_kwargs)

        if accuracy_kwargs is None:
            accuracy_kwargs = {'name': 'accuracy'}
        accuracy = accuracy_function(**accuracy_kwargs)

        self._model.compile(optimizer=optimizer, loss=loss, metrics=[accuracy])

        training_features = self._training_features.shuffle(
            self._training_size).batch(training_batch_size).repeat(-1)
        validation_features = self._validation_features.batch(
            validation_batch_size)

        training_steps = self._training_size // training_batch_size
        if training_steps == 0:
            training_steps = self._training_size
        logging.info(f'training_steps: {training_steps}')

        validation_steps = self._validation_split // validation_batch_size
        if validation_steps == 0:
            validation_steps = self._validation_split
        logging.info(f'validation_steps: {validation_steps}')

        for i in range(epochs):
            self._model.fit(training_features,
                            epochs=1,
                            validation_data=validation_features,
                            steps_per_epoch=training_steps,
                            validation_steps=validation_steps,
                            **kwargs)

        # The fine-tuned model does not have the same input interface
        # after being exported and loaded again.
        self._reload_model()

    def predict_one(
        self,
        text,
        split_strategy=None,
        aggregation_strategy=None
    ):
        return next(
            self.predict([text],
                         batch_size=1,
                         split_strategy=split_strategy,
                         aggregation_strategy=aggregation_strategy))

    def predict(
        self,
        texts,
        batch_size=32,
        split_strategy=None,
        aggregation_strategy=None
    ):
        if split_strategy is None:
            yield from self._predict_batch(texts, batch_size)

        else:
            if aggregation_strategy is None:
                aggregation_strategy = AggregationStrategies.Mean

            split_indexes = [0]
            sentences = []
            for text in texts:
                new_sentences = split_strategy.split(text, self.tokenizer)
                if not new_sentences:
                    continue
                split_indexes.append(split_indexes[-1] + len(new_sentences))
                sentences.extend(new_sentences)

            predictions = list(self._predict_batch(sentences, batch_size))
            for i, split_index in enumerate(split_indexes[:-1]):
                stop_index = split_indexes[i + 1]
                yield aggregation_strategy.aggregate(
                    predictions[split_index:stop_index]
                )

    def dump(self, path):
        if self._model_path:
            copy_dir(self._model_path, path)
        else:
            self._dump(path)

    def _dump(self, path):
        make_dir(path)
        make_dir(path + '/tokenizer')
        self._model.save_pretrained(path)
        self._tokenizer.save_pretrained(path + '/tokenizer')
        self._config.save_pretrained(path + '/tokenizer')

    def _predict_batch(self, sentences: list, batch_size: int):
        sentences_number = len(sentences)
        if batch_size > sentences_number:
            batch_size = sentences_number

        for i in range(0, sentences_number, batch_size):
            input_ids_list = []
            attention_mask_list = []

            stop_index = i + batch_size
            stop_index = stop_index if stop_index < sentences_number \
                else sentences_number

            for j in range(i, stop_index):
                features = self._tokenizer.encode_plus(
                    sentences[j],
                    add_special_tokens=True,
                    max_length=self._tokenizer.model_max_length
                )
                input_ids, _, attention_mask = (
                    features['input_ids'],
                    features['token_type_ids'],
                    features['attention_mask']
                )

                input_ids = self._list_to_padded_array(features['input_ids'])
                attention_mask = self._list_to_padded_array(
                    features['attention_mask'])

                input_ids_list.append(input_ids)
                attention_mask_list.append(attention_mask)

            input_dict = {
                'input_ids': np.array(input_ids_list),
                'attention_mask': np.array(attention_mask_list)
            }
            logit_predictions = self._model.predict_on_batch(input_dict)
            yield from (
                [softmax(logit_prediction)
                 for logit_prediction in logit_predictions[0]]
            )

    def _list_to_padded_array(self, items):
        array = np.array(items)
        padded_array = np.zeros(self._tokenizer.model_max_length, dtype=np.int)
        padded_array[:array.shape[0]] = array
        return padded_array

    def _get_temporary_path(self, name=''):
        return f'{AUTOSAVE_PATH}{name}/{int(round(time.time() * 1000))}'

    def _reload_model(self):
        self._model_path = self._get_temporary_path(
            name=self._get_model_family())
        self._dump(self._model_path)
        self._load_local_model(self._model_path)

    def _load_local_model(self, model_path):
        try:
            self._tokenizer = AutoTokenizer.from_pretrained(
                model_path + '/tokenizer')
            self._config = AutoConfig.from_pretrained(
                model_path + '/tokenizer')

        # Old models didn't use to have a tokenizer folder
        except OSError:
            self._tokenizer = AutoTokenizer.from_pretrained(model_path)
            self._config = AutoConfig.from_pretrained(model_path)
        self._model = TFAutoModelForSequenceClassification.from_pretrained(
            model_path,
            from_pt=False
        )

    def _get_model_family(self):
        model_family = ''.join(self._model.name[2:].split('_')[:2])
        return model_family

    def _load_remote_model(self, model_name, tokenizer_kwargs, model_kwargs):
        do_lower_case = False
        if 'uncased' in model_name.lower():
            do_lower_case = True
        tokenizer_kwargs.update({'do_lower_case': do_lower_case})

        self._tokenizer = AutoTokenizer.from_pretrained(
            model_name, **tokenizer_kwargs)
        self._config = AutoConfig.from_pretrained(model_name)

        temporary_path = self._get_temporary_path()
        make_dir(temporary_path)

        # TensorFlow model
        try:
            self._model = TFAutoModelForSequenceClassification.from_pretrained(
                model_name,
                from_pt=False
            )

        # PyTorch model
        except TypeError:
            try:
                self._model = \
                    TFAutoModelForSequenceClassification.from_pretrained(
                        model_name,
                        from_pt=True
                    )

            # Loading a TF model from a PyTorch checkpoint is not supported
            # when using a model identifier name
            except OSError:
                model = AutoModel.from_pretrained(model_name)
                model.save_pretrained(temporary_path)
                self._model = \
                    TFAutoModelForSequenceClassification.from_pretrained(
                        temporary_path,
                        from_pt=True
                    )

        # Clean the model's last layer if the provided properties are different
        clean_last_layer = False
        for key, value in model_kwargs.items():
            if not hasattr(self._model.config, key):
                clean_last_layer = True
                break

            if getattr(self._model.config, key) != value:
                clean_last_layer = True
                break

        if clean_last_layer:
            try:
                getattr(self._model, self._get_model_family()
                        ).save_pretrained(temporary_path)
                self._model = self._model.__class__.from_pretrained(
                    temporary_path,
                    from_pt=False,
                    **model_kwargs
                )

            # The model is itself the main layer
            except AttributeError:
                # TensorFlow model
                try:
                    self._model = self._model.__class__.from_pretrained(
                        model_name,
                        from_pt=False,
                        **model_kwargs
                    )

                # PyTorch Model
                except (OSError, TypeError):
                    model = AutoModel.from_pretrained(model_name)
                    model.save_pretrained(temporary_path)
                    self._model = self._model.__class__.from_pretrained(
                        temporary_path,
                        from_pt=True,
                        **model_kwargs
                    )

        remove_dir(temporary_path)
        assert self._tokenizer and self._model
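Taken together, the intended life cycle of SentenceClassifier is: instantiate, load a two-column dataframe (text, label), fine-tune, predict, dump. A condensed usage sketch under those assumptions (the dataframe contents and hyperparameters are placeholders, not part of this commit):

    import pandas as pd
    from ernie.ernie import SentenceClassifier, Models, SplitStrategies

    # Placeholder data: first column is the sentence, second the integer label.
    df = pd.DataFrame([
        ('This is amazing.', 1),
        ('This is terrible.', 0),
        ('Absolutely loved it.', 1),
        ('Would not recommend.', 0),
    ])

    classifier = SentenceClassifier(
        model_name=Models.BertBaseUncased, max_length=64, labels_no=2)
    classifier.load_dataset(df, validation_split=0.25)
    classifier.fine_tune(epochs=1, training_batch_size=2, validation_batch_size=2)

    # Single sentence -> tuple of class probabilities.
    print(classifier.predict_one('It was fine, I guess.'))

    # Long text: split into sentence groups, classify each, aggregate the results.
    print(classifier.predict_one(
        'First sentence. Second sentence. Third sentence.',
        split_strategy=SplitStrategies.GroupedSentencesWithoutUrls))

    # Export the fine-tuned model and its tokenizer.
    classifier.dump('./my-fine-tuned-model')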
ernie/helper.py ADDED
@@ -0,0 +1,121 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from tensorflow import data, TensorShape, int64, int32
from math import exp
from os import makedirs
from shutil import rmtree, move, copytree
from huggingface_hub import hf_hub_download
import os


def get_features(tokenizer, sentences, labels):
    features = []
    for i, sentence in enumerate(sentences):
        inputs = tokenizer.encode_plus(
            sentence,
            add_special_tokens=True,
            max_length=tokenizer.model_max_length
        )
        input_ids, token_type_ids = \
            inputs['input_ids'], inputs['token_type_ids']
        padding_length = tokenizer.model_max_length - len(input_ids)

        if tokenizer.padding_side == 'right':
            attention_mask = [1] * len(input_ids) + [0] * padding_length
            input_ids = input_ids + [tokenizer.pad_token_id] * padding_length
            token_type_ids = token_type_ids + \
                [tokenizer.pad_token_type_id] * padding_length
        else:
            attention_mask = [0] * padding_length + [1] * len(input_ids)
            input_ids = [tokenizer.pad_token_id] * padding_length + input_ids
            token_type_ids = \
                [tokenizer.pad_token_type_id] * padding_length + token_type_ids

        assert tokenizer.model_max_length \
            == len(attention_mask) \
            == len(input_ids) \
            == len(token_type_ids)

        feature = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids,
            'label': int(labels[i])
        }

        features.append(feature)

    def gen():
        for feature in features:
            yield (
                {
                    'input_ids': feature['input_ids'],
                    'attention_mask': feature['attention_mask'],
                    'token_type_ids': feature['token_type_ids'],
                },
                feature['label'],
            )

    dataset = data.Dataset.from_generator(
        gen,
        ({
            'input_ids': int32,
            'attention_mask': int32,
            'token_type_ids': int32
        }, int64),
        (
            {
                'input_ids': TensorShape([None]),
                'attention_mask': TensorShape([None]),
                'token_type_ids': TensorShape([None]),
            },
            TensorShape([]),
        ),
    )

    return dataset


def softmax(values):
    exps = [exp(value) for value in values]
    exps_sum = sum(exp_value for exp_value in exps)
    return tuple(map(lambda x: x / exps_sum, exps))


def make_dir(path):
    try:
        makedirs(path)
    except FileExistsError:
        pass


def remove_dir(path):
    rmtree(path)


def copy_dir(source_path, target_path):
    copytree(source_path, target_path)


def move_dir(source_path, target_path):
    move(source_path, target_path)


def download_from_hub(repo_id, filename, revision=None, cache_dir=None):
    try:
        hf_hub_download(repo_id=repo_id, filename=filename, revision=revision, cache_dir=cache_dir)
    except Exception as exp:
        raise exp

    if cache_dir is not None:
        files = os.listdir(cache_dir)
        for f in files:
            if '.lock' in f:
                name = f[0:-5]
                os.rename(cache_dir+name, cache_dir+filename)
                os.remove(cache_dir+name+'.lock')
                os.remove(cache_dir+name+'.json')
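The softmax helper exp-normalizes a tuple of logits, which is how the raw predictions from the classifier are turned into probabilities. A quick check with hand-computable numbers (illustrative only):

    from ernie.helper import softmax

    # Equal logits give a uniform distribution...
    print(softmax((0.0, 0.0)))        # (0.5, 0.5)

    # ...and shifting every logit by a constant leaves the result unchanged.
    print(softmax((0.0, 1.0)))        # approx (0.269, 0.731)
    print(softmax((2.0, 3.0)))        # same values as above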
ernie/models.py ADDED
@@ -0,0 +1,51 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-


class Models:
    BertBaseUncased = 'bert-base-uncased'
    BertBaseCased = 'bert-base-cased'
    BertLargeUncased = 'bert-large-uncased'
    BertLargeCased = 'bert-large-cased'

    RobertaBaseCased = 'roberta-base'
    RobertaLargeCased = 'roberta-large'

    XLNetBaseCased = 'xlnet-base-cased'
    XLNetLargeCased = 'xlnet-large-cased'

    DistilBertBaseUncased = 'distilbert-base-uncased'
    DistilBertBaseMultilingualCased = 'distilbert-base-multilingual-cased'

    AlbertBaseCased = 'albert-base-v1'
    AlbertLargeCased = 'albert-large-v1'
    AlbertXLargeCased = 'albert-xlarge-v1'
    AlbertXXLargeCased = 'albert-xxlarge-v1'

    AlbertBaseCased2 = 'albert-base-v2'
    AlbertLargeCased2 = 'albert-large-v2'
    AlbertXLargeCased2 = 'albert-xlarge-v2'
    AlbertXXLargeCased2 = 'albert-xxlarge-v2'


class ModelsByFamily:
    Bert = set([Models.BertBaseUncased, Models.BertBaseCased,
                Models.BertLargeUncased, Models.BertLargeCased])
    Roberta = set([Models.RobertaBaseCased, Models.RobertaLargeCased])
    XLNet = set([Models.XLNetBaseCased, Models.XLNetLargeCased])
    DistilBert = set([Models.DistilBertBaseUncased,
                      Models.DistilBertBaseMultilingualCased])
    Albert = set([
        Models.AlbertBaseCased,
        Models.AlbertLargeCased,
        Models.AlbertXLargeCased,
        Models.AlbertXXLargeCased,
        Models.AlbertBaseCased2,
        Models.AlbertLargeCased2,
        Models.AlbertXLargeCased2,
        Models.AlbertXXLargeCased2
    ])
    Supported = set([
        getattr(Models, model_type) for model_type
        in filter(lambda x: x[:2] != '__', Models.__dict__.keys())
    ])
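ModelsByFamily.Supported is built reflectively from every public attribute of Models, so it tracks the class automatically. A short sketch:

    from ernie.models import Models, ModelsByFamily

    print(Models.DistilBertBaseUncased)                       # 'distilbert-base-uncased'
    print(Models.RobertaBaseCased in ModelsByFamily.Roberta)  # True
    print('bert-base-uncased' in ModelsByFamily.Supported)    # True
    print(len(ModelsByFamily.Supported))                      # 18 model identifiers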
ernie/split_strategies.py ADDED
@@ -0,0 +1,125 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import re


class RegexExpressions:
    split_by_dot = re.compile(r'[^.]+(?:\.\s*)?')
    split_by_semicolon = re.compile(r'[^;]+(?:\;\s*)?')
    split_by_colon = re.compile(r'[^:]+(?:\:\s*)?')
    split_by_comma = re.compile(r'[^,]+(?:\,\s*)?')

    url = re.compile(
        r'https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}'
        r'\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)'
    )
    domain = re.compile(r'\w+\.\w+')


class SplitStrategy:
    def __init__(
        self,
        split_patterns,
        remove_patterns=None,
        group_splits=True,
        remove_too_short_groups=True
    ):
        if not isinstance(split_patterns, list):
            self.split_patterns = [split_patterns]
        else:
            self.split_patterns = split_patterns

        if remove_patterns is not None \
                and not isinstance(remove_patterns, list):
            self.remove_patterns = [remove_patterns]
        else:
            self.remove_patterns = remove_patterns

        self.group_splits = group_splits
        self.remove_too_short_groups = remove_too_short_groups

    def split(self, text, tokenizer, split_patterns=None):
        if split_patterns is None:
            if self.split_patterns is None:
                return [text]
            split_patterns = self.split_patterns

        def len_in_tokens(text_):
            no_tokens = len(tokenizer.encode(text_, add_special_tokens=False))
            return no_tokens

        no_special_tokens = len(tokenizer.encode('', add_special_tokens=True))
        max_tokens = tokenizer.max_len - no_special_tokens

        if self.remove_patterns is not None:
            for remove_pattern in self.remove_patterns:
                text = re.sub(remove_pattern, '', text).strip()

        if len_in_tokens(text) <= max_tokens:
            return [text]

        selected_splits = []
        splits = map(lambda x: x.strip(), re.findall(split_patterns[0], text))

        aggregated_splits = ''
        for split in splits:
            if len_in_tokens(split) > max_tokens:
                if len(split_patterns) > 1:
                    sub_splits = self.split(
                        split, tokenizer, split_patterns[1:])
                    selected_splits.extend(sub_splits)
                else:
                    selected_splits.append(split)

            else:
                if not self.group_splits:
                    selected_splits.append(split)
                else:
                    new_aggregated_splits = \
                        f'{aggregated_splits} {split}'.strip()
                    if len_in_tokens(new_aggregated_splits) <= max_tokens:
                        aggregated_splits = new_aggregated_splits
                    else:
                        selected_splits.append(aggregated_splits)
                        aggregated_splits = split

        if aggregated_splits:
            selected_splits.append(aggregated_splits)

        remove_too_short_groups = len(selected_splits) > 1 \
            and self.group_splits \
            and self.remove_too_short_groups

        if not remove_too_short_groups:
            final_splits = selected_splits
        else:
            final_splits = []
            min_length = tokenizer.max_len / 2
            for split in selected_splits:
                if len_in_tokens(split) >= min_length:
                    final_splits.append(split)

        return final_splits


class SplitStrategies:
    SentencesWithoutUrls = SplitStrategy(
        split_patterns=[
            RegexExpressions.split_by_dot,
            RegexExpressions.split_by_semicolon,
            RegexExpressions.split_by_colon,
            RegexExpressions.split_by_comma
        ],
        remove_patterns=[RegexExpressions.url, RegexExpressions.domain],
        remove_too_short_groups=False,
        group_splits=False)

    GroupedSentencesWithoutUrls = SplitStrategy(
        split_patterns=[
            RegexExpressions.split_by_dot,
            RegexExpressions.split_by_semicolon,
            RegexExpressions.split_by_colon,
            RegexExpressions.split_by_comma
        ],
        remove_patterns=[RegexExpressions.url, RegexExpressions.domain],
        remove_too_short_groups=True,
        group_splits=True)
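A sketch of how the split strategies behave, assuming a tokenizer object that exposes the max_len/encode interface this class expects (a tiny stub is used here so the example is self-contained; in real use the classifier's tokenizer is passed in):

    from ernie.split_strategies import SplitStrategies

    class StubTokenizer:
        """Minimal stand-in: one token per whitespace-separated word, max_len of 10."""
        max_len = 10

        def encode(self, text, add_special_tokens=False):
            tokens = text.split()
            if add_special_tokens:
                tokens = ['[CLS]'] + tokens + ['[SEP]']
            return tokens

    text = ('Visit https://example.com for details. '
            'The first sentence is short. The second sentence is also quite short.')

    # URLs and bare domains are stripped first, then the text is split by sentence
    # so that each piece fits within max_len minus the two special tokens.
    print(SplitStrategies.SentencesWithoutUrls.split(text, StubTokenizer()))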