"""Wrapper of an AllenNLP model that fixes errors based on model predictions."""
import logging
import os
import sys
from time import time

import torch
from allennlp.data.dataset import Batch
from allennlp.data.fields import TextField
from allennlp.data.instance import Instance
from allennlp.data.tokenizers import Token
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.nn import util

from gector.bert_token_embedder import PretrainedBertEmbedder
from gector.seq2labels_model import Seq2Labels
from gector.tokenizer_indexer import PretrainedBertIndexer
from utils.helpers import PAD, UNK, get_target_sent_by_edits, START_TOKEN
from utils.helpers import get_weights_name

logging.getLogger("werkzeug").setLevel(logging.ERROR)
logger = logging.getLogger(__file__)

class GecBERTModel(object):
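    """Wrapper around one or more Seq2Labels GECToR models.

    Corrects batches of tokenized sentences by iteratively predicting
    token-level edit tags and applying the suggested edits.
    """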

    def __init__(self, vocab_path=None, model_paths=None,
                 weigths=None,
                 max_len=50,
                 min_len=3,
                 lowercase_tokens=False,
                 log=False,
                 iterations=3,
                 model_name='roberta',
                 special_tokens_fix=1,
                 is_ensemble=True,
                 min_error_probability=0.0,
                 confidence=0,
                 del_confidence=0,
                 resolve_cycles=False,
                 ):
        # Per-model ensemble weights; defaults to uniform weighting.
        self.model_weights = list(map(float, weigths)) if weigths else [1] * len(model_paths)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.max_len = max_len
        self.min_len = min_len
        self.lowercase_tokens = lowercase_tokens
        self.min_error_probability = min_error_probability
        self.vocab = Vocabulary.from_files(vocab_path)
        self.log = log
        self.iterations = iterations
        self.confidence = confidence
        self.del_conf = del_confidence
        self.resolve_cycles = resolve_cycles

        # Build one indexer/model pair per checkpoint path.
        self.indexers = []
        self.models = []
        for model_path in model_paths:
            if is_ensemble:
                model_name, special_tokens_fix = self._get_model_data(model_path)
            weights_name = get_weights_name(model_name, lowercase_tokens)
            self.indexers.append(self._get_indexer(weights_name, special_tokens_fix))
            model = Seq2Labels(vocab=self.vocab,
                               text_field_embedder=self._get_embedder(weights_name, special_tokens_fix),
                               confidence=self.confidence,
                               del_confidence=self.del_conf,
                               ).to(self.device)
            if torch.cuda.is_available():
                model.load_state_dict(torch.load(model_path), strict=False)
            else:
                model.load_state_dict(torch.load(model_path,
                                                 map_location=torch.device('cpu')),
                                      strict=False)
            model.eval()
            self.models.append(model)

    @staticmethod
    def _get_model_data(model_path):
        # The checkpoint filename is assumed to encode the transformer name
        # and the special_tokens_fix flag, e.g. 'roberta_1_gector.th'.
        model_name = model_path.split('/')[-1]
        tr_model, stf = model_name.split('_')[:2]
        return tr_model, int(stf)

    def _restore_model(self, input_path):
        if os.path.isdir(input_path):
            print("Model could not be restored from directory", file=sys.stderr)
            filenames = []
        else:
            filenames = [input_path]
        for model_path in filenames:
            try:
                if torch.cuda.is_available():
                    loaded_model = torch.load(model_path)
                else:
                    loaded_model = torch.load(model_path,
                                              map_location=lambda storage, loc: storage)
            except Exception:
                print(f"{model_path} is not a valid model", file=sys.stderr)
                continue
            own_state = self.model.state_dict()
            for name, weights in loaded_model.items():
                if name not in own_state:
                    continue
                try:
                    if len(filenames) == 1:
                        own_state[name].copy_(weights)
                    else:
                        # With several checkpoints, accumulate parameters
                        # instead of overwriting them.
                        own_state[name] += weights
                except RuntimeError:
                    continue
        print("Model is restored", file=sys.stderr)

    def predict(self, batches):
        t_start = time()
        predictions = []
        # Each ensemble model scores its own indexed copy of the batch.
        for batch, model in zip(batches, self.models):
            batch = util.move_to_device(batch.as_tensor_dict(),
                                        0 if torch.cuda.is_available() else -1)
            with torch.no_grad():
                prediction = model.forward(**batch)
            predictions.append(prediction)

        preds, idx, error_probs = self._convert(predictions)
        t_end = time()
        if self.log:
            print(f"Inference time {t_end - t_start}")
        return preds, idx, error_probs

    def get_token_action(self, token, index, prob, sugg_token):
        """Get the suggested edit action for a token."""
        # Skip suggestions that are unreliable or are no-ops.
        if prob < self.min_error_probability or sugg_token in [UNK, PAD, '$KEEP']:
            return None

        if sugg_token.startswith('$REPLACE_') or sugg_token.startswith('$TRANSFORM_') or sugg_token == '$DELETE':
            start_pos = index
            end_pos = index + 1
        elif sugg_token.startswith("$APPEND_") or sugg_token.startswith("$MERGE_"):
            start_pos = index + 1
            end_pos = index + 1
        else:
            # Unknown tag type: nothing to apply.
            return None

        if sugg_token == "$DELETE":
            sugg_token_clear = ""
        elif sugg_token.startswith('$TRANSFORM_') or sugg_token.startswith("$MERGE_"):
            sugg_token_clear = sugg_token[:]
        else:
            sugg_token_clear = sugg_token[sugg_token.index('_') + 1:]

        return start_pos - 1, end_pos - 1, sugg_token_clear, prob
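
    # A worked example (hypothetical values): with min_error_probability=0.5,
    # calling get_token_action(token='there', index=3, prob=0.9,
    # sugg_token='$REPLACE_their') yields (2, 3, 'their', 0.9), i.e. replace
    # the span [2, 3) of the sentence; positions are shifted by one because
    # index 0 belongs to the prepended $START token.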

    def _get_embedder(self, weights_name, special_tokens_fix):
        embedders = {'bert': PretrainedBertEmbedder(
            pretrained_model=weights_name,
            requires_grad=False,
            top_layer_only=True,
            special_tokens_fix=special_tokens_fix)
        }
        text_field_embedder = BasicTextFieldEmbedder(
            token_embedders=embedders,
            embedder_to_indexer_map={"bert": ["bert", "bert-offsets"]},
            allow_unmatched_keys=True)
        return text_field_embedder

    def _get_indexer(self, weights_name, special_tokens_fix):
        bert_token_indexer = PretrainedBertIndexer(
            pretrained_model=weights_name,
            do_lowercase=self.lowercase_tokens,
            max_pieces_per_token=5,
            special_tokens_fix=special_tokens_fix
        )
        return {'bert': bert_token_indexer}

    def preprocess(self, token_batch):
        seq_lens = [len(sequence) for sequence in token_batch if sequence]
        if not seq_lens:
            return []
        max_len = min(max(seq_lens), self.max_len)
        batches = []
        for indexer in self.indexers:
            batch = []
            for sequence in token_batch:
                tokens = sequence[:max_len]
                # Prepend the special $START token expected by the tagger.
                tokens = [Token(token) for token in ['$START'] + tokens]
                batch.append(Instance({'tokens': TextField(tokens, indexer)}))
            batch = Batch(batch)
            batch.index_instances(self.vocab)
            batches.append(batch)

        return batches
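
    # For example, a 3-model ensemble given a 2-sentence token_batch yields a
    # list of 3 Batch objects, each holding the same 2 sentences indexed with
    # the corresponding model's indexer.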

    def _convert(self, data):
        all_class_probs = torch.zeros_like(data[0]['class_probabilities_labels'])
        error_probs = torch.zeros_like(data[0]['max_error_probability'])
        # Combine the ensemble outputs as a weighted average.
        for output, weight in zip(data, self.model_weights):
            all_class_probs += weight * output['class_probabilities_labels'] / sum(self.model_weights)
            error_probs += weight * output['max_error_probability'] / sum(self.model_weights)

        max_vals = torch.max(all_class_probs, dim=-1)
        probs = max_vals[0].tolist()
        idx = max_vals[1].tolist()
        return probs, idx, error_probs.tolist()
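
    # A worked example (hypothetical numbers): with model_weights=[2, 1], a
    # class probability of 0.8 from the first model and 0.5 from the second
    # averages to (2 * 0.8 + 1 * 0.5) / 3 = 0.7.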

    def update_final_batch(self, final_batch, pred_ids, pred_batch,
                           prev_preds_dict):
        new_pred_ids = []
        total_updated = 0
        for i, orig_id in enumerate(pred_ids):
            orig = final_batch[orig_id]
            pred = pred_batch[i]
            prev_preds = prev_preds_dict[orig_id]
            if orig != pred and pred not in prev_preds:
                final_batch[orig_id] = pred
                new_pred_ids.append(orig_id)
                prev_preds_dict[orig_id].append(pred)
                total_updated += 1
            elif orig != pred and pred in prev_preds:
                # Update the final batch, but stop iterating on this sentence
                # to avoid cycling between previously seen predictions.
                final_batch[orig_id] = pred
                total_updated += 1
            else:
                continue
        return final_batch, new_pred_ids, total_updated

    def postprocess_batch(self, batch, all_probabilities, all_idxs,
                          error_probs):
        all_results = []
        noop_index = self.vocab.get_token_index("$KEEP", "labels")
        for tokens, probabilities, idxs, error_prob in zip(batch,
                                                           all_probabilities,
                                                           all_idxs,
                                                           error_probs):
            length = min(len(tokens), self.max_len)
            edits = []

            # Skip the whole sentence if no errors were predicted.
            if max(idxs) == 0:
                all_results.append(tokens)
                continue

            # Skip the whole sentence if its error probability is too low.
            if error_prob < self.min_error_probability:
                all_results.append(tokens)
                continue

            for i in range(length + 1):
                # Index 0 corresponds to the prepended $START token.
                if i == 0:
                    token = START_TOKEN
                else:
                    token = tokens[i - 1]
                # Skip $KEEP tags.
                if idxs[i] == noop_index:
                    continue

                sugg_token = self.vocab.get_token_from_index(idxs[i],
                                                             namespace='labels')
                action = self.get_token_action(token, i, probabilities[i],
                                               sugg_token)
                if not action:
                    continue

                edits.append(action)
            all_results.append(get_target_sent_by_edits(tokens, edits))
        return all_results

    def handle_batch(self, full_batch):
        """
        Iteratively correct a batch of tokenized sentences.
        """
        final_batch = full_batch[:]
        batch_size = len(full_batch)
        prev_preds_dict = {i: [final_batch[i]] for i in range(len(final_batch))}
        # Sentences shorter than min_len are left unchanged.
        short_ids = [i for i in range(len(full_batch))
                     if len(full_batch[i]) < self.min_len]
        pred_ids = [i for i in range(len(full_batch)) if i not in short_ids]
        total_updates = 0

        for n_iter in range(self.iterations):
            orig_batch = [final_batch[i] for i in pred_ids]

            sequences = self.preprocess(orig_batch)

            if not sequences:
                break
            probabilities, idxs, error_probs = self.predict(sequences)

            pred_batch = self.postprocess_batch(orig_batch, probabilities,
                                                idxs, error_probs)
            if self.log:
                print(f"Iteration {n_iter + 1}. Predicted {round(100 * len(pred_ids) / batch_size, 1)}% of sentences.")

            final_batch, pred_ids, cnt = \
                self.update_final_batch(final_batch, pred_ids, pred_batch,
                                        prev_preds_dict)
            total_updates += cnt

            if not pred_ids:
                break

        return final_batch, total_updates
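

# A minimal usage sketch (the vocabulary and checkpoint paths below are
# placeholders for files shipped with GECToR and are not guaranteed to exist):
if __name__ == '__main__':
    gec_model = GecBERTModel(vocab_path='data/output_vocabulary',
                             model_paths=['roberta_1_gector.th'],
                             model_name='roberta',
                             special_tokens_fix=1,
                             is_ensemble=False)
    # handle_batch takes a list of tokenized sentences and returns the
    # corrected sentences plus the number of edits applied.
    corrected_batch, total_updates = gec_model.handle_batch(
        [['she', 'are', 'reading', 'book']])
    print(corrected_batch, total_updates)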