Derur committed
Commit 69da708 · verified · 1 Parent(s): ab3dfc8

Upload 6 files

ru_stt_text_normalization/README.md ADDED
@@ -0,0 +1,33 @@
+ ![Normalization](https://pics.spark-in.me/upload/7c12fea58ff515ffb46df52b6050ace0.jpg)
+
+ # Russian STT Text Normalization
+
+ Russian text normalization pipeline for speech-to-text and other applications, based on tagging sequence-to-sequence (s2s) networks.
+
+ ## Requirements
+
+ - Python >= 3.6
+ - [PyTorch](https://pytorch.org/get-started/locally/) >= 1.4 for the s2s pipeline
+ - [tqdm](https://github.com/tqdm/tqdm) for the progress bar
+
+ ```
+ pip install torch
+ pip install tqdm
+ ```
+
+ ## Usage
+
+ ```python
+ from normalizer import Normalizer
+
+ text = 'С 12.01.1943 г. площадь сельсовета — 1785,5 га.'
+
+ norm = Normalizer()
+ result = norm.norm_text(text)
+ print(result)
+ ```
+
+ ```
+ >>> С двенадцатого января тысяча девятьсот сорок третьего года площадь сельсовета
+ >>> — тысяча семьсот восемьдесят пять целых и пять десятых гектара
+ ```
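The `Normalizer` constructor also accepts a `device` and a path to the TorchScript weights (`Normalizer(device='cpu', jit_model='jit_s2s.pt')`), so longer inputs can be handled the same way. A minimal sketch for normalizing a text file line by line; the file names and the `'cuda'` device are illustrative assumptions, not part of the upload:

```python
from normalizer import Normalizer

# Load the TorchScript model once; 'cuda' assumes a CUDA-enabled PyTorch install,
# otherwise keep the default device='cpu'.
norm = Normalizer(device='cuda', jit_model='jit_s2s.pt')

# Hypothetical files: one raw sentence per line in, one normalized sentence per line out.
with open('input.txt', encoding='utf-8') as f_in, \
        open('normalized.txt', 'w', encoding='utf-8') as f_out:
    for line in f_in:
        f_out.write(norm.norm_text(line.strip()) + '\n')
```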
ru_stt_text_normalization/__init__.py ADDED
File without changes (empty file)
ru_stt_text_normalization/jit_s2s.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0a21bfbbe6b0392cbeff97f400cf27bfc37f010220df49a435b8eeb1363e2797
+ size 3766801
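The entry above is a Git LFS pointer, so the actual ~3.8 MB TorchScript binary (not the three-line pointer text) must be present before the model can be loaded. A quick sanity check in Python; the repository-relative path is an assumption about where the files were cloned:

```python
import torch

# If only the LFS pointer text was downloaded, torch.jit.load fails immediately
# because the file is not a valid TorchScript archive.
model = torch.jit.load('ru_stt_text_normalization/jit_s2s.pt', map_location='cpu')
print(type(model))  # a ScriptModule subclass on success
```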
ru_stt_text_normalization/normalizer.py ADDED
@@ -0,0 +1,147 @@
+ import re
+ import torch
+ from string import printable, punctuation
+ from tqdm import tqdm
+ import warnings
+
+
+ class Normalizer:
+     def __init__(self,
+                  device='cpu',
+                  jit_model='jit_s2s.pt'):
+         super(Normalizer, self).__init__()
+
+         self.device = torch.device(device)
+
+         self.init_vocabs()
+
+         self.model = torch.jit.load(jit_model, map_location=device)
+         self.model.eval()
+         self.max_len = 150
+
+     def init_vocabs(self):
+         # Initializes source and target vocabularies
+
+         # vocabs
+         rus_letters = 'абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ'
+         spec_symbols = '¼³№¾⅞½⅔⅓⅛⅜²'
+         # numbers + eng + punctuation + space + rus
+         self.src_vocab = {token: i + 5 for i, token in enumerate(printable[:-5] + rus_letters + '«»—' + spec_symbols)}
+         # punctuation + space + rus
+         self.tgt_vocab = {token: i + 5 for i, token in enumerate(punctuation + rus_letters + ' ' + '«»—')}
+
+         unk = '#UNK#'
+         pad = '#PAD#'
+         sos = '#SOS#'
+         eos = '#EOS#'
+         tfo = '#TFO#'
+         for i, token in enumerate([unk, pad, sos, eos, tfo]):
+             self.src_vocab[token] = i
+             self.tgt_vocab[token] = i
+
+         for i, token_name in enumerate(['unk', 'pad', 'sos', 'eos', 'tfo']):
+             setattr(self, '{}_index'.format(token_name), i)
+
+         inv_src_vocab = {v: k for k, v in self.src_vocab.items()}
+         self.src2tgt = {src_i: self.tgt_vocab.get(src_symb, -1) for src_i, src_symb in inv_src_vocab.items()}
+
+     def keep_unknown(self, string):
+         reg = re.compile(r'[^{}]+'.format(re.escape(''.join(self.src_vocab.keys()))))  # re.escape guards against regex metacharacters
+         unk_list = re.findall(reg, string)
+
+         unk_ids = [range(m.start() + 1, m.end()) for m in re.finditer(reg, string) if m.end() - m.start() > 1]
+         flat_unk_ids = [i for sublist in unk_ids for i in sublist]
+
+         upd_string = ''.join([s for i, s in enumerate(string) if i not in flat_unk_ids])
+         return upd_string, unk_list
+
+     def _norm_string(self, string):
+         # Normalizes a single chunk
+
+         if len(string) == 0:
+             return string
+         string, unk_list = self.keep_unknown(string)
+
+         token_src_list = [self.src_vocab.get(s, self.unk_index) for s in list(string)]
+         src = token_src_list + [self.eos_index] + [self.pad_index]
+
+         src2tgt = [self.src2tgt[s] for s in src]
+         src2tgt = torch.LongTensor(src2tgt).to(self.device)
+
+         src = torch.LongTensor(src).unsqueeze(0).to(self.device)
+         with torch.no_grad():
+             out = self.model(src, src2tgt)
+         pred_words = self.decode_words(out, unk_list)
+         if len(pred_words) > 199:
+             warnings.warn("Sentence {} is too long".format(string), Warning)
+         return pred_words
+
+     def norm_text(self, text):
+         # Normalizes full text
+
+         # Splits sentences into small chunks with weighted length <= max_len:
+         # * weighted length - estimated length of the normalized sentence (digits are weighted 7x)
+         #
+         # 1. The full text is split into sentences by the "ending" symbols (\n, \t, ?, !);
+         # 2. Long sentences are additionally split into chunks, either at spaces or by cutting overly long words
+
+         splitters = '\n\t?!'
+         parts = [p for p in re.split(r'({})'.format('|\\'.join(splitters)), text) if p != '']
+         norm_parts = []
+         for part in tqdm(parts):
+             if part in splitters:
+                 norm_parts.append(part)
+             else:
+                 weighted_string = [7 if symb.isdigit() else 1 for symb in part]
+                 if sum(weighted_string) <= self.max_len:
+                     norm_parts.append(self._norm_string(part))
+                 else:
+                     spaces = [m.start() for m in re.finditer(' ', part)]
+                     start_point = 0
+                     end_point = 0
+                     curr_point = 0
+
+                     while start_point < len(part):
+                         if curr_point in spaces:
+                             if sum(weighted_string[start_point:curr_point]) < self.max_len:
+                                 end_point = curr_point + 1
+                             else:
+                                 norm_parts.append(self._norm_string(part[start_point:end_point]))
+                                 start_point = end_point
+
+                         elif sum(weighted_string[end_point:curr_point]) >= self.max_len:
+                             if end_point > start_point:
+                                 norm_parts.append(self._norm_string(part[start_point:end_point]))
+                                 start_point = end_point
+                             end_point = curr_point - 1
+                             norm_parts.append(self._norm_string(part[start_point:end_point]))
+                             start_point = end_point
+                         elif curr_point == len(part):
+                             norm_parts.append(self._norm_string(part[start_point:]))
+                             start_point = len(part)
+
+                         curr_point += 1
+         return ''.join(norm_parts)
+
+     def decode_words(self, pred, unk_list=None):
+         if unk_list is None:
+             unk_list = []
+         pred = pred.cpu().numpy()
+         pred_words = "".join(self.lookup_words(x=pred,
+                                                vocab={i: w for w, i in self.tgt_vocab.items()},
+                                                unk_list=unk_list))
+         return pred_words
+
+     def lookup_words(self, x, vocab, unk_list=None):
+         if unk_list is None:
+             unk_list = []
+         result = []
+         for i in x:
+             if i == self.unk_index:
+                 if len(unk_list) > 0:
+                     result.append(unk_list.pop(0))
+                 else:
+                     continue
+             else:
+                 result.append(vocab[i])
+         return [str(t) for t in result]
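As the comments in `norm_text` describe, long inputs are split into chunks whose weighted length stays within `max_len = 150`, where every digit counts as 7 characters because digits expand into full words after normalization. A self-contained sketch of that estimate, reusing the README example:

```python
# Same weighting as Normalizer.norm_text: digits count as 7 characters, everything else as 1.
def weighted_len(chunk: str) -> int:
    return sum(7 if symb.isdigit() else 1 for symb in chunk)

text = 'С 12.01.1943 г. площадь сельсовета — 1785,5 га.'
print(weighted_len(text))  # 125 -> fits in a single chunk (max_len is 150)
```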
ru_stt_text_normalization/requirements.txt ADDED
@@ -0,0 +1,2 @@
+ torch>=1.4.0
+ tqdm
ru_stt_text_normalization/ru_stt_text_normalization.7z ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:052f4f10e225266c2bdff939d2bd5459699b74082b54e173646867421f20a80f
+ size 3157061