| """Wrapper of AllenNLP model. Fixes errors based on model predictions""" | |
import logging
import os
import sys
from time import time

import torch
from allennlp.data.dataset import Batch
from allennlp.data.fields import TextField
from allennlp.data.instance import Instance
from allennlp.data.tokenizers import Token
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.nn import util

from gector.bert_token_embedder import PretrainedBertEmbedder
from gector.seq2labels_model import Seq2Labels
from gector.tokenizer_indexer import PretrainedBertIndexer
from utils.helpers import PAD, UNK, START_TOKEN, get_target_sent_by_edits, get_weights_name

logging.getLogger("werkzeug").setLevel(logging.ERROR)
logger = logging.getLogger(__file__)


class GecBERTModel(object):
    def __init__(self, vocab_path=None, model_paths=None,
                 weights=None,
                 max_len=50,
                 min_len=3,
                 lowercase_tokens=False,
                 log=False,
                 iterations=3,
                 model_name='roberta',
                 special_tokens_fix=1,
                 is_ensemble=True,
                 min_error_probability=0.0,
                 confidence=0,
                 del_confidence=0,
                 resolve_cycles=False,
                 ):
        self.model_weights = list(map(float, weights)) if weights else [1] * len(model_paths)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.max_len = max_len
        self.min_len = min_len
        self.lowercase_tokens = lowercase_tokens
        self.min_error_probability = min_error_probability
        self.vocab = Vocabulary.from_files(vocab_path)
        self.log = log
        self.iterations = iterations
        self.confidence = confidence
        self.del_conf = del_confidence
        self.resolve_cycles = resolve_cycles
        # build an indexer and a model for every checkpoint path
        self.indexers = []
        self.models = []
        for model_path in model_paths:
            if is_ensemble:
                model_name, special_tokens_fix = self._get_model_data(model_path)
            weights_name = get_weights_name(model_name, lowercase_tokens)
            self.indexers.append(self._get_indexer(weights_name, special_tokens_fix))
            model = Seq2Labels(vocab=self.vocab,
                               text_field_embedder=self._get_embedder(weights_name, special_tokens_fix),
                               confidence=self.confidence,
                               del_confidence=self.del_conf,
                               ).to(self.device)
            if torch.cuda.is_available():
                model.load_state_dict(torch.load(model_path), strict=False)
            else:
                model.load_state_dict(torch.load(model_path,
                                                 map_location=torch.device('cpu')),
                                      strict=False)
            model.eval()
            self.models.append(model)

    @staticmethod
    def _get_model_data(model_path):
        model_name = model_path.split('/')[-1]
        tr_model, stf = model_name.split('_')[:2]
        return tr_model, int(stf)
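
    # Illustration (hypothetical filename, not shipped with this module): a
    # checkpoint saved as "models/roberta_1_gector.th" yields ('roberta', 1),
    # i.e. the transformer name and the special_tokens_fix flag encoded in
    # the filename.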

    def _restore_model(self, input_path):
        if os.path.isdir(input_path):
            print("Model could not be restored from a directory", file=sys.stderr)
            filenames = []
        else:
            filenames = [input_path]
        for model_path in filenames:
            try:
                if torch.cuda.is_available():
                    loaded_model = torch.load(model_path)
                else:
                    loaded_model = torch.load(model_path,
                                              map_location=lambda storage, loc: storage)
            except Exception:
                print(f"{model_path} is not a valid model", file=sys.stderr)
                continue
            # copy the loaded weights into the first model's state;
            # with several checkpoints the parameter tensors are summed
            own_state = self.models[0].state_dict()
            for name, weights in loaded_model.items():
                if name not in own_state:
                    continue
                try:
                    if len(filenames) == 1:
                        own_state[name].copy_(weights)
                    else:
                        own_state[name] += weights
                except RuntimeError:
                    continue
        print("Model is restored", file=sys.stderr)

    def predict(self, batches):
        t_start = time()
        predictions = []
        for batch, model in zip(batches, self.models):
            batch = util.move_to_device(batch.as_tensor_dict(),
                                        0 if torch.cuda.is_available() else -1)
            with torch.no_grad():
                prediction = model.forward(**batch)
            predictions.append(prediction)

        preds, idx, error_probs = self._convert(predictions)
        t_end = time()
        if self.log:
            print(f"Inference time {t_end - t_start}")
        return preds, idx, error_probs

    def get_token_action(self, token, index, prob, sugg_token):
        """Get list of suggested actions for token."""
        # cases when we don't need to do anything
        if prob < self.min_error_probability or sugg_token in [UNK, PAD, '$KEEP']:
            return None

        if sugg_token.startswith('$REPLACE_') or sugg_token.startswith('$TRANSFORM_') or sugg_token == '$DELETE':
            start_pos = index
            end_pos = index + 1
        elif sugg_token.startswith("$APPEND_") or sugg_token.startswith("$MERGE_"):
            start_pos = index + 1
            end_pos = index + 1

        if sugg_token == "$DELETE":
            sugg_token_clear = ""
        elif sugg_token.startswith('$TRANSFORM_') or sugg_token.startswith("$MERGE_"):
            sugg_token_clear = sugg_token[:]
        else:
            sugg_token_clear = sugg_token[sugg_token.index('_') + 1:]

        return start_pos - 1, end_pos - 1, sugg_token_clear, prob
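
    # Illustration (hypothetical label): for sugg_token '$APPEND_the' at
    # index 2 with prob 0.9, the branches above set start_pos = end_pos = 3,
    # so the method returns the edit tuple (2, 2, 'the', 0.9), which
    # get_target_sent_by_edits later applies to the token list.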

    def _get_embedder(self, weights_name, special_tokens_fix):
        embedders = {'bert': PretrainedBertEmbedder(
            pretrained_model=weights_name,
            requires_grad=False,
            top_layer_only=True,
            special_tokens_fix=special_tokens_fix)
        }
        text_field_embedder = BasicTextFieldEmbedder(
            token_embedders=embedders,
            embedder_to_indexer_map={"bert": ["bert", "bert-offsets"]},
            allow_unmatched_keys=True)
        return text_field_embedder

    def _get_indexer(self, weights_name, special_tokens_fix):
        bert_token_indexer = PretrainedBertIndexer(
            pretrained_model=weights_name,
            do_lowercase=self.lowercase_tokens,
            max_pieces_per_token=5,
            special_tokens_fix=special_tokens_fix
        )
        return {'bert': bert_token_indexer}

    def preprocess(self, token_batch):
        seq_lens = [len(sequence) for sequence in token_batch if sequence]
        if not seq_lens:
            return []
        max_len = min(max(seq_lens), self.max_len)
        batches = []
        for indexer in self.indexers:
            batch = []
            for sequence in token_batch:
                tokens = sequence[:max_len]
                tokens = [Token(token) for token in [START_TOKEN] + tokens]
                batch.append(Instance({'tokens': TextField(tokens, indexer)}))
            batch = Batch(batch)
            batch.index_instances(self.vocab)
            batches.append(batch)
        return batches

    def _convert(self, data):
        all_class_probs = torch.zeros_like(data[0]['class_probabilities_labels'])
        error_probs = torch.zeros_like(data[0]['max_error_probability'])
        total_weight = sum(self.model_weights)
        for output, weight in zip(data, self.model_weights):
            all_class_probs += weight * output['class_probabilities_labels'] / total_weight
            error_probs += weight * output['max_error_probability'] / total_weight

        max_vals = torch.max(all_class_probs, dim=-1)
        probs = max_vals[0].tolist()
        idx = max_vals[1].tolist()
        return probs, idx, error_probs.tolist()
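
    # Illustration: with model_weights = [2, 1] the ensemble distribution is
    # (2 * P1 + 1 * P2) / 3, a weighted average of the per-model
    # class-probability tensors; error probabilities are averaged the same way.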

    def update_final_batch(self, final_batch, pred_ids, pred_batch,
                           prev_preds_dict):
        new_pred_ids = []
        total_updated = 0
        for i, orig_id in enumerate(pred_ids):
            orig = final_batch[orig_id]
            pred = pred_batch[i]
            prev_preds = prev_preds_dict[orig_id]
            if orig != pred and pred not in prev_preds:
                final_batch[orig_id] = pred
                new_pred_ids.append(orig_id)
                prev_preds_dict[orig_id].append(pred)
                total_updated += 1
            elif orig != pred and pred in prev_preds:
                # update the final batch, but stop iterating on this sentence,
                # since this prediction was already seen (a cycle)
                final_batch[orig_id] = pred
                total_updated += 1
            else:
                continue
        return final_batch, new_pred_ids, total_updated

    def postprocess_batch(self, batch, all_probabilities, all_idxs,
                          error_probs):
        all_results = []
        noop_index = self.vocab.get_token_index("$KEEP", "labels")
        for tokens, probabilities, idxs, error_prob in zip(batch,
                                                           all_probabilities,
                                                           all_idxs,
                                                           error_probs):
            length = min(len(tokens), self.max_len)
            edits = []

            # skip the whole sentence if there are no errors
            if max(idxs) == 0:
                all_results.append(tokens)
                continue

            # skip the whole sentence if the error probability is below the threshold
            if error_prob < self.min_error_probability:
                all_results.append(tokens)
                continue

            for i in range(length + 1):
                # shift by one because of the START token
                if i == 0:
                    token = START_TOKEN
                else:
                    token = tokens[i - 1]
                # skip if there is no error
                if idxs[i] == noop_index:
                    continue

                sugg_token = self.vocab.get_token_from_index(idxs[i],
                                                             namespace='labels')
                action = self.get_token_action(token, i, probabilities[i],
                                               sugg_token)
                if not action:
                    continue

                edits.append(action)
            all_results.append(get_target_sent_by_edits(tokens, edits))
        return all_results

    def handle_batch(self, full_batch):
        """
        Handle batch of requests.
        """
        final_batch = full_batch[:]
        batch_size = len(full_batch)
        prev_preds_dict = {i: [final_batch[i]] for i in range(len(final_batch))}
        short_ids = [i for i in range(len(full_batch))
                     if len(full_batch[i]) < self.min_len]
        pred_ids = [i for i in range(len(full_batch)) if i not in short_ids]
        total_updates = 0

        for n_iter in range(self.iterations):
            orig_batch = [final_batch[i] for i in pred_ids]

            sequences = self.preprocess(orig_batch)
            if not sequences:
                break
            probabilities, idxs, error_probs = self.predict(sequences)

            pred_batch = self.postprocess_batch(orig_batch, probabilities,
                                                idxs, error_probs)
            if self.log:
                print(f"Iteration {n_iter + 1}. Predicted {round(100 * len(pred_ids) / batch_size, 1)}% of sentences.")

            final_batch, pred_ids, cnt = \
                self.update_final_batch(final_batch, pred_ids, pred_batch,
                                        prev_preds_dict)
            total_updates += cnt

            if not pred_ids:
                break

        return final_batch, total_updates
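

# A minimal usage sketch. The vocabulary directory and checkpoint path below
# are assumptions for illustration, not files shipped with this module; the
# model expects pre-tokenized sentences and returns the corrected token lists
# plus the number of edits applied across iterations.
if __name__ == '__main__':
    model = GecBERTModel(vocab_path='data/output_vocabulary',        # assumed path
                         model_paths=['models/roberta_1_gector.th'], # assumed checkpoint
                         model_name='roberta',
                         is_ensemble=False)
    corrected_batch, n_updates = model.handle_batch([['she', 'are', 'happy']])
    print(corrected_batch, n_updates)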