# -*- coding: utf-8 -*-
import operator
import re
from transformers import BertTokenizer, BertForMaskedLM
import gradio as gr
import opencc
import torch
pretrained_model_name_or_path = "Macropodus/macbert4mdcspell_v2"
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path)
model = BertForMaskedLM.from_pretrained(pretrained_model_name_or_path)
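# The model is only used for prediction; eval mode disables dropout during inference.
model.eval()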
vocab = tokenizer.vocab
# from modelscope import AutoTokenizer, AutoModelForMaskedLM
# pretrained_model_name_or_path = "Macadam/macbert4mdcspell_v2"
# tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
# model = AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)
# vocab = tokenizer.vocab
converter_t2s = opencc.OpenCC("t2s.json")
context = converter_t2s.convert("漢字")  # sanity check: traditional "漢字" -> simplified "汉字"
PUN_EN2ZH_DICT = {",": ",", ";": ";", "!": "!", "?": "?", ":": ":", "(": "(", ")": ")", "_": "—"}
PUN_BERT_DICT = {"“":'"', "”":'"', "‘":'"', "’":'"', "—": "_", "——": "__"}
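# Both punctuation maps keep the character count unchanged (each key is replaced by a value of the
# same length), so the character offsets reported in the error details stay aligned with the input.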
def func_macro_correct(text):
with torch.no_grad():
outputs = model(**tokenizer([text], padding=True, return_tensors='pt'))
def flag_total_chinese(text):
"""
        judge whether the text is made up entirely of Chinese characters
Args:
text: str, eg. "macadam, 碎石路"
Returns:
bool, True or False
"""
for word in text:
if not "\u4e00" <= word <= "\u9fa5":
return False
return True
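    # e.g. flag_total_chinese("碎石路") -> True, flag_total_chinese("macadam, 碎石路") -> False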
    def get_errors_from_diff_length(corrected_text, origin_text, unk_tokens=None, know_tokens=()):
"""Get errors between corrected text and origin text
code from: https://github.com/shibing624/pycorrector
"""
new_corrected_text = ""
errors = []
i, j = 0, 0
unk_tokens = unk_tokens or [' ', '“', '”', '‘', '’', '琊', '\n', '…', '擤', '\t', '玕', '']
while i < len(origin_text) and j < len(corrected_text):
if origin_text[i] in unk_tokens or origin_text[i] not in know_tokens:
new_corrected_text += origin_text[i]
i += 1
elif corrected_text[j] in unk_tokens:
new_corrected_text += corrected_text[j]
j += 1
# Deal with Chinese characters
elif flag_total_chinese(origin_text[i]) and flag_total_chinese(corrected_text[j]):
# If the two characters are the same, then the two pointers move forward together
if origin_text[i] == corrected_text[j]:
new_corrected_text += corrected_text[j]
i += 1
j += 1
else:
# Check for insertion errors
if j + 1 < len(corrected_text) and origin_text[i] == corrected_text[j + 1]:
errors.append(('', corrected_text[j], j))
new_corrected_text += corrected_text[j]
j += 1
# Check for deletion errors
elif i + 1 < len(origin_text) and origin_text[i + 1] == corrected_text[j]:
errors.append((origin_text[i], '', i))
i += 1
# Check for replacement errors
else:
errors.append((origin_text[i], corrected_text[j], i))
new_corrected_text += corrected_text[j]
i += 1
j += 1
else:
new_corrected_text += origin_text[i]
if origin_text[i] == corrected_text[j]:
j += 1
i += 1
errors = sorted(errors, key=operator.itemgetter(2))
return new_corrected_text, errors
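    # Each error entry is (wrong_char, right_char, position): ('', c, j) marks an insertion,
    # (c, '', i) a deletion, and (a, b, i) a replacement.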
    def get_errors_from_same_length(corrected_text, origin_text, unk_tokens=None, know_tokens=()):
"""Get new corrected text and errors between corrected text and origin text
code from: https://github.com/shibing624/pycorrector
"""
errors = []
unk_tokens = unk_tokens or [' ', '“', '”', '‘', '’', '琊', '\n', '…', '擤', '\t', '玕', '', ',']
for i, ori_char in enumerate(origin_text):
if i >= len(corrected_text):
continue
if ori_char in unk_tokens or ori_char not in know_tokens:
# deal with unk word
corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
continue
if ori_char != corrected_text[i]:
if not flag_total_chinese(ori_char):
                    # the original character is not Chinese; keep it as-is
corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
continue
if not flag_total_chinese(corrected_text[i]):
corrected_text = corrected_text[:i] + corrected_text[i + 1:]
continue
errors.append([ori_char, corrected_text[i], i])
errors = sorted(errors, key=operator.itemgetter(2))
return corrected_text, errors
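    # Returns the adjusted corrected text plus [wrong_char, right_char, position] triples sorted by position.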
_text = tokenizer.decode(torch.argmax(outputs.logits[0], dim=-1), skip_special_tokens=True).replace(' ', '')
corrected_text = _text[:len(text)]
print("#" * 128)
print(text)
print(corrected_text)
print(len(text), len(corrected_text))
if len(corrected_text) == len(text):
corrected_text, details = get_errors_from_same_length(corrected_text, text, know_tokens=vocab)
else:
corrected_text, details = get_errors_from_diff_length(corrected_text, text, know_tokens=vocab)
print(text, ' => ', corrected_text, details)
# return corrected_text + ' ' + str(details)
line_dict = {"source": text, "target": corrected_text, "errors": details}
return line_dict
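# Illustrative call (output shape only; the actual corrections depend on the model):
#   func_macro_correct("真麻烦你了。希望你们好好的跳无")
#   -> {"source": <input>, "target": <corrected text>, "errors": [[wrong_char, right_char, position], ...]}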
def transfor_english_symbol_to_chinese(text, kv_dict=PUN_EN2ZH_DICT):
""" 将英文标点符号转化为中文标点符号, 位数不能变防止pos_id变化 """
for k, v in kv_dict.items(): # 英文替换
text = text.replace(k, v)
if text and text[-1] == ".": # 最后一个字符是英文.
text = text[:-1] + "。"
if text and "\"" in text: # 双引号
index_list = [i.start() for i in re.finditer("\"", text)]
if index_list:
for idx, index in enumerate(index_list):
symbol = "“" if idx % 2 == 0 else "”"
text = text[:index] + symbol + text[index + 1:]
if text and "'" in text: # 单引号
index_list = [i.start() for i in re.finditer("'", text)]
if index_list:
for idx, index in enumerate(index_list):
symbol = "‘" if idx % 2 == 0 else "’"
text = text[:index] + symbol + text[index + 1:]
return text
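# Illustrative example: transfor_english_symbol_to_chinese('他说:"好的!"') -> '他说:“好的!”'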
def cut_sent_by_stay(text, return_length=True, add_semicolon=False):
""" 分句但是保存原标点符号 """
if add_semicolon:
text_sp = re.split(r"!”|?”|。”|……”|”!|”?|”。|”……|》。|)。|;|!|?|。|…|\!|\?", text)
conn_symbol = ";!?。…”;!?》)\n"
else:
text_sp = re.split(r"!”|?”|。”|……”|”!|”?|”。|”……|》。|)。|!|?|。|…|\!|\?", text)
conn_symbol = "!?。…”!?》)\n"
text_length_s = []
text_cut = []
len_text = len(text) - 1
# signal_symbol = "—”>;?…)‘《’(·》“~,、!。:<"
len_global = 0
for idx, text_sp_i in enumerate(text_sp):
text_cut_idx = text_sp[idx]
        len_global_before = len_global  # start offset of this sentence in the original text
len_global += len(text_sp_i)
while True:
if len_global <= len_text and text[len_global] in conn_symbol:
text_cut_idx += text[len_global]
else:
# len_global += 1
if text_cut_idx:
text_length_s.append([len_global_before, len_global])
text_cut.append(text_cut_idx)
break
len_global += 1
if return_length:
return text_cut, text_length_s
return text_cut
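# Illustrative example:
#   cut_sent_by_stay("今天很好!明天也好。") -> (["今天很好!", "明天也好。"], [[0, 5], [5, 10]])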
def transfor_bert_unk_pun_to_know(text, kv_dict=PUN_BERT_DICT):
""" 将英文标点符号转化为中文标点符号, 位数不能变防止pos_id变化 """
for k, v in kv_dict.items(): # 英文替换
text = text.replace(k, v)
return text
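# Illustrative example: transfor_bert_unk_pun_to_know('“你好”——') -> '"你好"__'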
def tradition_to_simple(text):
""" 繁体到简体 """
return converter_t2s.convert(text)
def string_q2b(ustring):
"""把字符串全角转半角"""
return "".join([q2b(uchar) for uchar in ustring])
def q2b(uchar):
"""全角转半角"""
inside_code = ord(uchar)
if inside_code == 0x3000:
inside_code = 0x0020
else:
inside_code -= 0xfee0
    if inside_code < 0x0020 or inside_code > 0x7e:  # not a printable half-width character after conversion; return the original
return uchar
return chr(inside_code)
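# Illustrative example: string_q2b("ABC 123?") -> 'ABC 123?'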
def func_macro_correct_long(text):
""" 长句 """
texts, length = cut_sent_by_stay(text, return_length=True, add_semicolon=True)
text_correct = ""
errors_new = []
for idx, text in enumerate(texts):
        # pre-processing
text = transfor_english_symbol_to_chinese(text)
text = string_q2b(text)
text = tradition_to_simple(text)
text = transfor_bert_unk_pun_to_know(text)
text_out = func_macro_correct(text)
source = text_out.get("source")
target = text_out.get("target")
errors = text_out.get("errors")
text_correct += target
for error in errors:
pos = length[idx][0] + error[-1]
error_1 = [error[0], error[1], pos]
errors_new.append(error_1)
return text_correct + '\n' + str(errors_new)
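# Output format: the corrected text, a newline, then the stringified error list, where each entry is
# [wrong_char, right_char, position] with positions measured in the full pre-processed input.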
if __name__ == '__main__':
text = """网购的烦脑
emer 发布于 2025-7-3 18:20 阅读:73
最近网购遇到件恼火的事。我在网店看中件羽戎服,店家保正是正品,还承诺七天无里由退换。收到货后却发现袖口有开线,更糟的是拉链老是卡住。
联系客服时,对方态度敷衔,先说让我自行缝补,后又说要扣除运废才给退。我在评沦区如实描述经历,结果发现好多消废者都有类似遭遇。
这次购物让我明白,不能光看店家的宣全,要多查考真实评价。现在我已经学精了,下单前总会反复合对商品信息。
网购的烦恼发布于2025-7-310期阅读:最近网购遇到件恼火的事。我在网店看中件羽绒服,店家保证是正品,还承诺七天无理由退换。收到货后却发现袖口有开线,更糟的是拉链老是卡住。联系客服时,对方态度敷衍,先说让我自行缝补,后又说要扣除运废才给退。我在评论区如实描述经历,结果发现好多消废者都有类似遭遇。这次购物让我明白,不能光看店家的宣全,要多查考真实评价。现在我已经学精了,下单前总会反复核对商品信息。
网购的烦恼e发布于2025-7-3期期阅读:最近网购遇到件恼火的事。我在网店看中件羽绒服,店家保证是正品,还承诺七天无理由退换。收到货后却发现袖口有开线,更糟的是拉链老是卡住。联系客服时,对方态度敷衍,先说让我自行缝补,后又说要扣除运废才给退。我在评论区如实描述经历,结果发现好多消废者都有类似遭遇。这次购物让我明白,不能光看店家的宣全,要多查考真实评价。现在我已经学精了,下单前总会反复核对商品信息。网购的烦恼发布于2025-7-310期阅读:最近网购遇到件恼火的事。我在网店看中件羽绒服,店家保证是正品,还承诺七天无理由退换。收到货后却发现袖口有开线,更糟的是拉链老是卡住。联系客服时,对方态度敷衍,先说让我自行缝补,后又说要扣除运废才给退。我在评论区如实描述经历,结果发现好多消废者都有类似遭遇。这次购物让我明白,不能光看店家的宣全,要多查考真实评价。现在我已经学精了,下单前总会反复核对商品信息。"""
print(func_macro_correct_long(text))
examples = [
"夫谷之雨,犹复云之亦从的起,因与疾风俱飘,参于天,集于的。",
"机七学习是人工智能领遇最能体现智能的一个分知",
'他们的吵翻很不错,再说他们做的咖喱鸡也好吃',
"抗疫路上,除了提心吊胆也有难的得欢笑。",
"我是练习时长两念半的鸽仁练习生蔡徐坤",
"清晨,如纱一般地薄雾笼罩着世界。",
"得府许我立庙于此,故请君移去尔。",
"他法语说的很好,的语也不错",
"遇到一位很棒的奴生跟我疗天",
"五年级得数学,我考的很差。",
"我们为这个目标努力不解",
'今天兴情很好',
]
gr.Interface(
func_macro_correct_long,
inputs='text',
outputs='text',
title="Chinese Spelling Correction Model Macropodus/macbert4mdcspell_v2",
description="Copy or input error Chinese text. Submit and the machine will correct text.",
article="Link to <a href='https://github.com/yongzhuo/macro-correct' style='color:blue;' target='_blank\'>Github REPO: macro-correct</a>",
examples=examples
    ).launch()