import re
import warnings
from string import printable, punctuation

import torch
from tqdm import tqdm


class Normalizer:
    def __init__(self,
                 device='cpu',
                 jit_model='jit_s2s.pt'):
        super().__init__()

        self.device = torch.device(device)
        self.init_vocabs()

        # Load the TorchScript seq2seq model and switch it to inference mode
        self.model = torch.jit.load(jit_model, map_location=device)
        self.model.eval()
        # Maximum weighted chunk length fed to the model in one pass (see norm_text)
        self.max_len = 150

    def init_vocabs(self):
        # Build the source and target character vocabularies.
        rus_letters = 'абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ'
        spec_symbols = '¼³№¾⅞½⅔⅓⅛⅜²'

        # Indices 0-4 are reserved for the service tokens defined below.
        # Source side: printable ASCII without the trailing control whitespace
        # of string.printable ('\t\n\r\x0b\x0c'), Russian letters, «»— and a
        # few numeric glyphs; target side: punctuation, Russian letters,
        # space and «»—.
        self.src_vocab = {token: i + 5 for i, token in enumerate(printable[:-5] + rus_letters + '«»—' + spec_symbols)}
        self.tgt_vocab = {token: i + 5 for i, token in enumerate(punctuation + rus_letters + ' ' + '«»—')}

        unk = '#UNK#'
        pad = '#PAD#'
        sos = '#SOS#'
        eos = '#EOS#'
        tfo = '#TFO#'
        for i, token in enumerate([unk, pad, sos, eos, tfo]):
            self.src_vocab[token] = i
            self.tgt_vocab[token] = i

        # Expose self.unk_index, self.pad_index, self.sos_index,
        # self.eos_index and self.tfo_index
        for i, token_name in enumerate(['unk', 'pad', 'sos', 'eos', 'tfo']):
            setattr(self, '{}_index'.format(token_name), i)

        # Map each source index to the matching target index,
        # or -1 when the symbol has no counterpart in the target vocabulary
        inv_src_vocab = {v: k for k, v in self.src_vocab.items()}
        self.src2tgt = {src_i: self.tgt_vocab.get(src_symb, -1) for src_i, src_symb in inv_src_vocab.items()}

    def keep_unknown(self, string):
        # The model cannot consume characters outside the source vocabulary.
        # Collapse every out-of-vocabulary run to its first character (which
        # will be encoded as #UNK#) and save the full runs so they can be
        # restored verbatim after decoding.  re.escape guards the regex
        # metacharacters that occur among the vocabulary symbols.
        reg = re.compile('[^{}]+'.format(re.escape(''.join(self.src_vocab.keys()))))
        unk_list = re.findall(reg, string)

        unk_ids = [range(m.start() + 1, m.end()) for m in re.finditer(reg, string) if m.end() - m.start() > 1]
        flat_unk_ids = [i for sublist in unk_ids for i in sublist]

        upd_string = ''.join([s for i, s in enumerate(string) if i not in flat_unk_ids])
        return upd_string, unk_list
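
    # For instance, assuming emoji are outside the source vocabulary,
    # keep_unknown('привет 😀😀😀 мир') would return
    # ('привет 😀 мир', ['😀😀😀']): the run is collapsed to a single
    # placeholder character and saved for restoration after decoding.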

    def _norm_string(self, string):
        # Normalize a single chunk: encode characters to source indices,
        # run the seq2seq model, decode the prediction back into text.
        if len(string) == 0:
            return string
        string, unk_list = self.keep_unknown(string)

        token_src_list = [self.src_vocab.get(s, self.unk_index) for s in string]
        src = token_src_list + [self.eos_index] + [self.pad_index]

        src2tgt = [self.src2tgt[s] for s in src]
        src2tgt = torch.LongTensor(src2tgt).to(self.device)

        src = torch.LongTensor(src).unsqueeze(0).to(self.device)
        with torch.no_grad():
            out = self.model(src, src2tgt)
        pred_words = self.decode_words(out, unk_list)
        # Warn if the decoded output looks suspiciously long
        if len(pred_words) > 199:
            warnings.warn("Sentence {} is too long".format(string), Warning)
        return pred_words

    def norm_text(self, text):
        # Normalize arbitrary text: split it on sentence-like boundaries and
        # normalize each part separately.  Digits get a weight of 7 because
        # they expand into whole words, so the weighted length estimates the
        # output size; parts heavier than self.max_len are split further on
        # spaces.
        splitters = '\n\t?!'
        parts = [p for p in re.split('({})'.format('|'.join(re.escape(s) for s in splitters)), text) if p != '']
        norm_parts = []
        for part in tqdm(parts):
            if part in splitters:
                # Pass splitter characters through untouched
                norm_parts.append(part)
            else:
                weighted_string = [7 if symb.isdigit() else 1 for symb in part]
                if sum(weighted_string) <= self.max_len:
                    norm_parts.append(self._norm_string(part))
                else:
                    # The part is too heavy for a single pass: scan it and cut
                    # at the last space that still fits under max_len
                    spaces = [m.start() for m in re.finditer(' ', part)]
                    start_point = 0
                    end_point = 0
                    curr_point = 0

                    while start_point < len(part):
                        if curr_point in spaces:
                            if sum(weighted_string[start_point:curr_point]) < self.max_len:
                                # Still fits: remember this space as a cut candidate
                                end_point = curr_point + 1
                            else:
                                norm_parts.append(self._norm_string(part[start_point:end_point]))
                                start_point = end_point
                        elif sum(weighted_string[end_point:curr_point]) >= self.max_len:
                            # No suitable space available: force a cut inside the token
                            if end_point > start_point:
                                norm_parts.append(self._norm_string(part[start_point:end_point]))
                                start_point = end_point
                            end_point = curr_point - 1
                            norm_parts.append(self._norm_string(part[start_point:end_point]))
                            start_point = end_point
                        elif curr_point == len(part):
                            # Flush the remaining tail
                            norm_parts.append(self._norm_string(part[start_point:]))
                            start_point = len(part)

                        curr_point += 1
        return ''.join(norm_parts)
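
    # Worked example of the digit weighting above: for the part '25 мая' the
    # weighted length is 7 + 7 + 1 + 1 + 1 + 1 = 18 (two digits at weight 7,
    # then a space and three letters at weight 1), well under max_len = 150,
    # so the part is normalized in a single pass.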

    def decode_words(self, pred, unk_list=None):
        # Convert predicted target indices back into a string, restoring the
        # out-of-vocabulary runs collected by keep_unknown
        if unk_list is None:
            unk_list = []
        pred = pred.cpu().numpy()
        pred_words = ''.join(self.lookup_words(x=pred,
                                               vocab={i: w for w, i in self.tgt_vocab.items()},
                                               unk_list=unk_list))
        return pred_words

    def lookup_words(self, x, vocab, unk_list=None):
        if unk_list is None:
            unk_list = []
        result = []
        for i in x:
            if i == self.unk_index:
                # Substitute the next saved unknown run, if any remain
                if len(unk_list) > 0:
                    result.append(unk_list.pop(0))
                else:
                    continue
            else:
                result.append(vocab[i])
        return [str(t) for t in result]
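

# A minimal usage sketch, assuming the TorchScript checkpoint 'jit_s2s.pt'
# sits in the working directory (the sample sentence is illustrative only):
if __name__ == '__main__':
    norm = Normalizer(device='cpu', jit_model='jit_s2s.pt')
    print(norm.norm_text('было закуплено 12 единиц техники'))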