Macropodus committed on
Commit
bd24f17
·
verified ·
1 Parent(s): 760a845

preprocess

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py CHANGED
@@ -6,6 +6,7 @@ import re
6
 
7
  from transformers import BertTokenizer, BertForMaskedLM
8
  import gradio as gr
 
9
  import torch
10
 
11
 
@@ -20,6 +21,10 @@ vocab = tokenizer.vocab
20
  # tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
21
  # model = AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)
22
  # vocab = tokenizer.vocab
 
 
 
 
23
 
24
 
25
  def func_macro_correct(text):
@@ -127,6 +132,27 @@ def func_macro_correct(text):
127
  return line_dict
128
 
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  def cut_sent_by_stay(text, return_length=True, add_semicolon=False):
131
  """ 分句但是保存原标点符号 """
132
  if add_semicolon:
@@ -157,6 +183,27 @@ def cut_sent_by_stay(text, return_length=True, add_semicolon=False):
157
  if return_length:
158
  return text_cut, text_length_s
159
  return text_cut
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
 
162
  def func_macro_correct_long(text):
@@ -165,6 +212,12 @@ def func_macro_correct_long(text):
165
  text_correct = ""
166
  errors_new = []
167
  for idx, text in enumerate(texts):
 
 
 
 
 
 
168
  text_out = func_macro_correct(text)
169
  source = text_out.get("source")
170
  target = text_out.get("target")
 
6
 
7
  from transformers import BertTokenizer, BertForMaskedLM
8
  import gradio as gr
9
+ import opencc
10
  import torch
11
 
12
 
 
21
  # tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
22
  # model = AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)
23
  # vocab = tokenizer.vocab
24
# Module-level OpenCC converter loaded with the "t2s.json" configuration;
# tradition_to_simple() routes its input through this object.
converter_t2s = opencc.OpenCC("t2s.json")
# NOTE(review): the literal passed here is already in simplified form, and
# `context` is not read anywhere in the visible code -- this looks like a
# smoke test / warm-up of the converter; confirm before removing.
context = converter_t2s.convert("汉字")
# Default replacement table for transfor_english_symbol_to_chinese().
# NOTE(review): as shown here every entry except "_" maps a character to
# itself; the values were presumably full-width punctuation in the original
# file -- verify against the repository.
PUN_EN2ZH_DICT = {",": ",", ";": ";", "!": "!", "?": "?", ":": ":", "(": "(", ")": ")", "_": "—"}
# Default replacement table for transfor_bert_unk_pun_to_know(): curly quotes
# become the ASCII double quote and em dashes become underscores.
# NOTE(review): the single "—" entry is applied before "——" (dict order), so
# the two-character key never matches; the resulting text is the same.
PUN_BERT_DICT = {"“":'"', "”":'"', "‘":'"', "’":'"', "—": "_", "——": "__"}
28
 
29
 
30
  def func_macro_correct(text):
 
132
  return line_dict
133
 
134
 
135
def transfor_english_symbol_to_chinese(text, kv_dict=PUN_EN2ZH_DICT):
    """Convert English punctuation in ``text`` to Chinese punctuation.

    Steps, in order:
      1. replace every key of ``kv_dict`` with its value;
      2. turn a trailing ASCII full stop into "。";
      3. rewrite ASCII double quotes as alternating “ / ” and ASCII single
         quotes as alternating ‘ / ’ (even occurrences open, odd ones close).

    Steps 2 and 3 swap one character for one character, so character
    positions are preserved as long as ``kv_dict`` also maps keys to values
    of the same length.
    """
    for src, dst in kv_dict.items():
        text = text.replace(src, dst)

    # Only a full stop in the very last position is converted.
    if text.endswith("."):
        text = text[:-1] + "。"

    quote_table = (("\"", "“", "”"), ("'", "‘", "’"))
    for ascii_quote, opening, closing in quote_table:
        if ascii_quote not in text:
            continue
        chars = list(text)
        seen = 0
        for pos, ch in enumerate(chars):
            if ch == ascii_quote:
                # Occurrences are paired purely by order of appearance.
                chars[pos] = closing if seen % 2 else opening
                seen += 1
        text = "".join(chars)
    return text
156
  def cut_sent_by_stay(text, return_length=True, add_semicolon=False):
157
  """ 分句但是保存原标点符号 """
158
  if add_semicolon:
 
183
  if return_length:
184
  return text_cut, text_length_s
185
  return text_cut
186
def transfor_bert_unk_pun_to_know(text, kv_dict=PUN_BERT_DICT):
    """Apply the ``kv_dict`` punctuation replacements to ``text``.

    With the default table, curly quotes are replaced by the ASCII double
    quote and em dashes by underscores. Entries are applied one after
    another in the dictionary's iteration order, and the rewritten string
    is returned.
    """
    replaced = text
    for unknown_pun, known_pun in kv_dict.items():
        replaced = replaced.replace(unknown_pun, known_pun)
    return replaced
191
def tradition_to_simple(text):
    """Convert Traditional Chinese text to Simplified Chinese.

    Delegates to the module-level OpenCC converter created with the
    "t2s.json" configuration and returns its result.
    """
    simplified = converter_t2s.convert(text)
    return simplified
194
def string_q2b(ustring):
    """Convert every full-width character of ``ustring`` to half-width.

    Each character is passed through ``q2b`` and the results are joined
    back into a single string.
    """
    return "".join(map(q2b, ustring))
197
def q2b(uchar):
    """Map one full-width character to its half-width counterpart.

    U+3000 becomes an ASCII space; any other character has its code point
    shifted down by 0xFEE0. If the shifted code point does not land in the
    range 0x20..0x7E, the input character is returned unchanged.
    """
    original_code = ord(uchar)
    shifted_code = 0x0020 if original_code == 0x3000 else original_code - 0xfee0
    if 0x0020 <= shifted_code <= 0x7e:
        return chr(shifted_code)
    # Not a full-width form of a printable ASCII character: leave as is.
    return uchar
207
 
208
 
209
  def func_macro_correct_long(text):
 
212
  text_correct = ""
213
  errors_new = []
214
  for idx, text in enumerate(texts):
215
+ # 前处理
216
+ text = transfor_english_symbol_to_chinese(text)
217
+ text = string_q2b(text)
218
+ text = tradition_to_simple(text)
219
+ text = transfor_bert_unk_pun_to_know(text)
220
+
221
  text_out = func_macro_correct(text)
222
  source = text_out.get("source")
223
  target = text_out.get("target")