import re
import torch
from string import printable, punctuation
from tqdm import tqdm
import warnings
class Normalizer:
def __init__(self,
device='cpu',
jit_model='jit_s2s.pt'):
        super().__init__()
self.device = torch.device(device)
self.init_vocabs()
self.model = torch.jit.load(jit_model, map_location=device)
self.model.eval()
        self.max_len = 150  # maximum weighted chunk length fed to the model
def init_vocabs(self):
        # Initializes the source and target vocabularies
rus_letters = 'абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ'
spec_symbols = '¼³№¾⅞½⅔⅓⅛⅜²'
# numbers + eng + punctuation + space + rus
self.src_vocab = {token: i + 5 for i, token in enumerate(printable[:-5] + rus_letters + '«»—' + spec_symbols)}
# punctuation + space + rus
self.tgt_vocab = {token: i + 5 for i, token in enumerate(punctuation + rus_letters + ' ' + '«»—')}
unk = '#UNK#'
pad = '#PAD#'
sos = '#SOS#'
eos = '#EOS#'
tfo = '#TFO#'
for i, token in enumerate([unk, pad, sos, eos, tfo]):
self.src_vocab[token] = i
self.tgt_vocab[token] = i
for i, token_name in enumerate(['unk', 'pad', 'sos', 'eos', 'tfo']):
setattr(self, '{}_index'.format(token_name), i)
inv_src_vocab = {v: k for k, v in self.src_vocab.items()}
self.src2tgt = {src_i: self.tgt_vocab.get(src_symb, -1) for src_i, src_symb in inv_src_vocab.items()}
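        # Resulting layout (a sketch, relying on insertion-ordered dicts,
        # Python 3.7+): indices 0-4 hold the special tokens (#UNK#, #PAD#,
        # #SOS#, #EOS#, #TFO#), ordinary symbols start at index 5, and
        # self.src2tgt maps each source index to the matching target index,
        # or to -1 for symbols absent from the target vocabulary (e.g.
        # digits, which the model must spell out rather than copy).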
def keep_unknown(self, string):
        # Escape vocab symbols so that regex metacharacters (\, ], ^, -)
        # cannot break the character class
        reg = re.compile(r'[^{}]+'.format(re.escape(''.join(self.src_vocab.keys()))))
        matches = list(re.finditer(reg, string))
        unk_list = [m.group() for m in matches]
        unk_ids = [range(m.start() + 1, m.end()) for m in matches if m.end() - m.start() > 1]
flat_unk_ids = [i for sublist in unk_ids for i in sublist]
upd_string = ''.join([s for i, s in enumerate(string) if i not in flat_unk_ids])
return upd_string, unk_list
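    # Hedged example of keep_unknown (the emoji is purely illustrative):
    #   keep_unknown('Привет 🙂🙂🙂!') -> ('Привет 🙂!', ['🙂🙂🙂'])
    # Only the first character of each unknown run is kept; it encodes to
    # #UNK#, and decode_words() later splices the full runs back in.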
def _norm_string(self, string):
# Normalizes chunk
if len(string) == 0:
return string
string, unk_list = self.keep_unknown(string)
        token_src_list = [self.src_vocab.get(s, self.unk_index) for s in list(string)]
        # append the end-of-sequence and padding indices expected by the model
        src = token_src_list + [self.eos_index] + [self.pad_index]
        # per-position projection of source indices onto target indices
        # (-1 where a source symbol has no target counterpart, e.g. digits)
        src2tgt = [self.src2tgt[s] for s in src]
src2tgt = torch.LongTensor(src2tgt).to(self.device)
src = torch.LongTensor(src).unsqueeze(0).to(self.device)
with torch.no_grad():
out = self.model(src, src2tgt)
pred_words = self.decode_words(out, unk_list)
        # Warn when the decoded output for a single chunk is unexpectedly long
        if len(pred_words) > 199:
            warnings.warn("Sentence {} is too long".format(string), Warning)
return pred_words
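    # Shape sketch for _norm_string: for a 10-character chunk, src has shape
    # (1, 12) (the characters plus #EOS# and #PAD#) and src2tgt has shape (12,).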
def norm_text(self, text):
        # Normalizes text
        # Splits sentences into small chunks with weighted length <= max_len:
        # * weighted length - the estimated length of the normalized sentence
        #
        # 1. The full text is split on "ending" symbols (\n\t?!) into sentences;
        # 2. Long sentences are additionally split into chunks: at spaces, or by simply cutting over-long words
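        # Illustrative example of the weighting used below (digits weigh 7
        # because they expand when verbalized): for part = 'abc 12',
        # weighted_string == [1, 1, 1, 1, 7, 7], a weighted length of 18.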
        splitters = '\n\t?!'
        # split on the delimiters while keeping them (capturing group), so the
        # text can be re-joined verbatim
        parts = [p for p in re.split(r'({})'.format('|\\'.join(splitters)), text) if p != '']
norm_parts = []
for part in tqdm(parts):
if part in splitters:
norm_parts.append(part)
else:
weighted_string = [7 if symb.isdigit() else 1 for symb in part]
if sum(weighted_string) <= self.max_len:
norm_parts.append(self._norm_string(part))
else:
spaces = [m.start() for m in re.finditer(' ', part)]
start_point = 0
end_point = 0
curr_point = 0
                    # scan with curr_point; end_point tracks the last safe cut
                    # (just after a space) such that part[start_point:end_point]
                    # stays under max_len
                    while start_point < len(part):
                        if curr_point in spaces:
                            # at a space: either extend the chunk past it, or
                            # flush the chunk at the last safe cut
                            if sum(weighted_string[start_point:curr_point]) < self.max_len:
                                end_point = curr_point + 1
                            else:
                                norm_parts.append(self._norm_string(part[start_point:end_point]))
                                start_point = end_point
                        elif sum(weighted_string[end_point:curr_point]) >= self.max_len:
                            # a single over-long word: flush up to the last safe
                            # cut first, then hard-cut the word one character
                            # before the scan position
                            if end_point > start_point:
                                norm_parts.append(self._norm_string(part[start_point:end_point]))
                                start_point = end_point
                            end_point = curr_point - 1
                            norm_parts.append(self._norm_string(part[start_point:end_point]))
                            start_point = end_point
                        elif curr_point == len(part):
                            # end of the part: flush the remainder
                            norm_parts.append(self._norm_string(part[start_point:]))
                            start_point = len(part)
                        curr_point += 1
return ''.join(norm_parts)
def decode_words(self, pred, unk_list=None):
if unk_list is None:
unk_list = []
pred = pred.cpu().numpy()
pred_words = "".join(self.lookup_words(x=pred,
vocab={i: w for w, i in self.tgt_vocab.items()},
unk_list=unk_list))
return pred_words
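    # Sketch: decode_words(pred, ['🙂🙂🙂']) turns each predicted #UNK# index
    # back into the next stored unknown run and maps every other index
    # through the inverted tgt_vocab.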
def lookup_words(self, x, vocab, unk_list=None):
if unk_list is None:
unk_list = []
result = []
for i in x:
if i == self.unk_index:
if len(unk_list) > 0:
result.append(unk_list.pop(0))
else:
continue
else:
result.append(vocab[i])
return [str(t) for t in result]
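

if __name__ == '__main__':
    # Minimal usage sketch: assumes 'jit_s2s.pt' sits next to this file;
    # the expected output ('купил пять яблок') is illustrative, not a
    # guaranteed model prediction.
    normalizer = Normalizer(device='cpu', jit_model='jit_s2s.pt')
    print(normalizer.norm_text('купил 5 яблок'))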