# -*- coding: utf-8 -*-

import operator
import copy
import re

from transformers import BertTokenizer, BertForMaskedLM
import gradio as gr
import opencc
import torch


pretrained_model_name_or_path = "Macropodus/macbert4mdcspell_v2"
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path)
model = BertForMaskedLM.from_pretrained(pretrained_model_name_or_path)
model.eval()  # inference only: disable dropout so corrections are deterministic
vocab = tokenizer.vocab


# from modelscope import AutoTokenizer, AutoModelForMaskedLM
# pretrained_model_name_or_path = "Macadam/macbert4mdcspell_v2"
# tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
# model = AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)
# vocab = tokenizer.vocab
converter_t2s = opencc.OpenCC("t2s.json")
context = converter_t2s.convert("漢字")  # sanity check: Traditional -> Simplified yields "汉字"
PUN_EN2ZH_DICT = {",": ",", ";": ";", "!": "!", "?": "?", ":": ":", "(": "(", ")": ")", "_": "—"}
PUN_BERT_DICT = {"“":'"', "”":'"', "‘":'"', "’":'"', "—": "_", "——": "__"}


def func_macro_correct(text):
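    """Correct one (short) sentence.

    Returns a dict: {"source": input text, "target": corrected text,
    "errors": [[wrong_char, right_char, position], ...]}.
    """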
    with torch.no_grad():
        outputs = model(**tokenizer([text], padding=True, return_tensors='pt'))

    def flag_total_chinese(text):
        """
        Return True if every character in the text is a Chinese character.
        Args:
            text: str, e.g. "macadam, 碎石路"
        Returns:
            bool, True or False
        """
        for word in text:
            if not "\u4e00" <= word <= "\u9fa5":
                return False
        return True
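    # e.g. flag_total_chinese("碎石路") -> True; flag_total_chinese("macadam, 碎石路") -> False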

    def get_errors_from_diff_length(corrected_text, origin_text, unk_tokens=None, know_tokens=()):
        """Get errors between corrected text and origin text
        code from:  https://github.com/shibing624/pycorrector
        """
        new_corrected_text = ""
        errors = []
        i, j = 0, 0
        unk_tokens = unk_tokens or [' ', '“', '”', '‘', '’', '琊', '\n', '…', '擤', '\t', '玕', '']
        while i < len(origin_text) and j < len(corrected_text):
            if origin_text[i] in unk_tokens or origin_text[i] not in know_tokens:
                new_corrected_text += origin_text[i]
                i += 1
            elif corrected_text[j] in unk_tokens:
                new_corrected_text += corrected_text[j]
                j += 1
            # Deal with Chinese characters
            elif flag_total_chinese(origin_text[i]) and flag_total_chinese(corrected_text[j]):
                # If the two characters are the same, then the two pointers move forward together
                if origin_text[i] == corrected_text[j]:
                    new_corrected_text += corrected_text[j]
                    i += 1
                    j += 1
                else:
                    # Check for insertion errors
                    if j + 1 < len(corrected_text) and origin_text[i] == corrected_text[j + 1]:
                        errors.append(('', corrected_text[j], j))
                        new_corrected_text += corrected_text[j]
                        j += 1
                    # Check for deletion errors
                    elif i + 1 < len(origin_text) and origin_text[i + 1] == corrected_text[j]:
                        errors.append((origin_text[i], '', i))
                        i += 1
                    # Check for replacement errors
                    else:
                        errors.append((origin_text[i], corrected_text[j], i))
                        new_corrected_text += corrected_text[j]
                        i += 1
                        j += 1
            else:
                new_corrected_text += origin_text[i]
                if origin_text[i] == corrected_text[j]:
                    j += 1
                i += 1
        errors = sorted(errors, key=operator.itemgetter(2))
        return new_corrected_text, errors

    def get_errors_from_same_length(corrected_text, origin_text, unk_tokens=None, know_tokens=()):
        """Get new corrected text and errors between corrected text and origin text
        code from:  https://github.com/shibing624/pycorrector
        """
        errors = []
        unk_tokens = unk_tokens or [' ', '“', '”', '‘', '’', '琊', '\n', '…', '擤', '\t', '玕', '', ',']

        for i, ori_char in enumerate(origin_text):
            if i >= len(corrected_text):
                continue
            if ori_char in unk_tokens or ori_char not in know_tokens:
                # deal with unk word
                corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
                continue
            if ori_char != corrected_text[i]:
                if not flag_total_chinese(ori_char):
                    # pass not chinese char
                    corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
                    continue
                if not flag_total_chinese(corrected_text[i]):
                    corrected_text = corrected_text[:i] + corrected_text[i + 1:]
                    continue
                errors.append([ori_char, corrected_text[i], i])
        errors = sorted(errors, key=operator.itemgetter(2))
        return corrected_text, errors
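    # Sketch of the same-length path (assumes every character is in the tokenizer
    # vocab): a single substitution at position 2, e.g.
    #   get_errors_from_same_length("今天心情很好", "今天兴情很好", know_tokens=vocab)
    #   -> ("今天心情很好", [["兴", "心", 2]])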

    _text = tokenizer.decode(torch.argmax(outputs.logits[0], dim=-1), skip_special_tokens=True).replace(' ', '')
    corrected_text = _text[:len(text)]
    print("#" * 128)
    print(text)
    print(corrected_text)
    print(len(text), len(corrected_text))
    if len(corrected_text) == len(text):
        corrected_text, details = get_errors_from_same_length(corrected_text, text, know_tokens=vocab)
    else:
        corrected_text, details = get_errors_from_diff_length(corrected_text, text, know_tokens=vocab)
    print(text, ' => ', corrected_text, details)
    # return corrected_text + ' ' + str(details)
    line_dict = {"source": text, "target": corrected_text, "errors": details}
    return line_dict
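
# Example (hypothetical output; exact corrections depend on the model):
#   func_macro_correct("今天兴情很好")
#   # -> {"source": "今天兴情很好", "target": "今天心情很好", "errors": [["兴", "心", 2]]}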


def transfor_english_symbol_to_chinese(text, kv_dict=PUN_EN2ZH_DICT):
    """Convert English punctuation to its Chinese counterpart.
    Every replacement is one character for one character, so the string
    length (and therefore every position id) stays the same.
    """
    for k, v in kv_dict.items():  # replace English punctuation
        text = text.replace(k, v)
    if text and text[-1] == ".":  # trailing English full stop
        text = text[:-1] + "。"

    if text and "\"" in text:  # 双引号
        index_list = [i.start() for i in re.finditer("\"", text)]
        if index_list:
            for idx, index in enumerate(index_list):
                symbol = "“" if idx % 2 == 0 else "”"
                text = text[:index] + symbol + text[index + 1:]

    if text and "'" in text:  # 单引号
        index_list = [i.start() for i in re.finditer("'", text)]
        if index_list:
            for idx, index in enumerate(index_list):
                symbol = "‘" if idx % 2 == 0 else "’"
                text = text[:index] + symbol + text[index + 1:]
    return text
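
# Example: punctuation is mapped one-for-one, so token positions are preserved:
#   transfor_english_symbol_to_chinese('你好,世界! "加油"')  # -> '你好,世界! “加油”'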


def cut_sent_by_stay(text, return_length=True, add_semicolon=False):
    """Split text into sentences while keeping the original punctuation.
    When return_length is True, also return each sentence's [start, end)
    offsets in the original text.
    """
    if add_semicolon:
        text_sp = re.split(r"!”|?”|。”|……”|”!|”?|”。|”……|》。|)。|;|!|?|。|…|\!|\?", text)
        conn_symbol = ";!?。…”;!?》)\n"
    else:
        text_sp = re.split(r"!”|?”|。”|……”|”!|”?|”。|”……|》。|)。|!|?|。|…|\!|\?", text)
        conn_symbol = "!?。…”!?》)\n"
    text_length_s = []
    text_cut = []
    len_text = len(text) - 1
    # signal_symbol = "—”>;?…)‘《’(·》“~,、!。:<"
    len_global = 0
    for idx, text_sp_i in enumerate(text_sp):
        text_cut_idx = text_sp[idx]
        len_global_before = copy.deepcopy(len_global)
        len_global += len(text_sp_i)
        while True:
            if len_global <= len_text and text[len_global] in conn_symbol:
                text_cut_idx += text[len_global]
            else:
                # len_global += 1
                if text_cut_idx:
                    text_length_s.append([len_global_before, len_global])
                    text_cut.append(text_cut_idx)
                break
            len_global += 1
    if return_length:
        return text_cut, text_length_s
    return text_cut
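
# Example: sentences come back with their [start, end) offsets into the
# original string, e.g.
#   cut_sent_by_stay("天气好。出门吧!")  # -> (["天气好。", "出门吧!"], [[0, 4], [4, 8]])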


def transfor_bert_unk_pun_to_know(text, kv_dict=PUN_BERT_DICT):
    """Map punctuation the BERT vocab does not know to known equivalents,
    keeping the string length unchanged.
    """
    for k, v in kv_dict.items():
        text = text.replace(k, v)
    return text


def tradition_to_simple(text):
    """Convert Traditional Chinese to Simplified Chinese."""
    return converter_t2s.convert(text)


def string_q2b(ustring):
    """Convert every full-width character in a string to half-width."""
    return "".join([q2b(uchar) for uchar in ustring])


def q2b(uchar):
    """Convert a single full-width character to its half-width form."""
    inside_code = ord(uchar)
    if inside_code == 0x3000:  # full-width space
        inside_code = 0x0020
    else:
        inside_code -= 0xfee0
    if inside_code < 0x0020 or inside_code > 0x7e:  # no half-width form: return the original character
        return uchar
    return chr(inside_code)
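
# Example: full-width "ABC 123" -> half-width "ABC 123"; characters with
# no half-width form (e.g. Chinese) pass through unchanged.
#   string_q2b("ABC 123")  # -> "ABC 123"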


def func_macro_correct_long(text):
    """Correct long text: split into sentences, correct each one, then merge."""
    texts, length = cut_sent_by_stay(text, return_length=True, add_semicolon=True)
    text_correct = ""
    errors_new = []
    for idx, text in enumerate(texts):
        # preprocessing: normalize punctuation, character width and script
        text = transfor_english_symbol_to_chinese(text)
        text = string_q2b(text)
        text = tradition_to_simple(text)
        text = transfor_bert_unk_pun_to_know(text)

        text_out = func_macro_correct(text)
        source = text_out.get("source")
        target = text_out.get("target")
        errors = text_out.get("errors")
        text_correct += target
        for error in errors:
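            # shift the per-sentence offset to an offset into the full input text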
            pos = length[idx][0] + error[-1]
            error_1 = [error[0], error[1], pos]
            errors_new.append(error_1)
    return text_correct + '\n' + str(errors_new)


if __name__ == '__main__':
    text = """网购的烦脑
emer 发布于 2025-7-3 18:20 阅读:73

最近网购遇到件恼火的事。我在网店看中件羽戎服,店家保正是正品,还承诺七天无里由退换。收到货后却发现袖口有开线,更糟的是拉链老是卡住。

联系客服时,对方态度敷衔,先说让我自行缝补,后又说要扣除运废才给退。我在评沦区如实描述经历,结果发现好多消废者都有类似遭遇。

这次购物让我明白,不能光看店家的宣全,要多查考真实评价。现在我已经学精了,下单前总会反复合对商品信息。
网购的烦恼发布于2025-7-310期阅读:最近网购遇到件恼火的事。我在网店看中件羽绒服,店家保证是正品,还承诺七天无理由退换。收到货后却发现袖口有开线,更糟的是拉链老是卡住。联系客服时,对方态度敷衍,先说让我自行缝补,后又说要扣除运废才给退。我在评论区如实描述经历,结果发现好多消废者都有类似遭遇。这次购物让我明白,不能光看店家的宣全,要多查考真实评价。现在我已经学精了,下单前总会反复核对商品信息。
网购的烦恼e发布于2025-7-3期期阅读:最近网购遇到件恼火的事。我在网店看中件羽绒服,店家保证是正品,还承诺七天无理由退换。收到货后却发现袖口有开线,更糟的是拉链老是卡住。联系客服时,对方态度敷衍,先说让我自行缝补,后又说要扣除运废才给退。我在评论区如实描述经历,结果发现好多消废者都有类似遭遇。这次购物让我明白,不能光看店家的宣全,要多查考真实评价。现在我已经学精了,下单前总会反复核对商品信息。网购的烦恼发布于2025-7-310期阅读:最近网购遇到件恼火的事。我在网店看中件羽绒服,店家保证是正品,还承诺七天无理由退换。收到货后却发现袖口有开线,更糟的是拉链老是卡住。联系客服时,对方态度敷衍,先说让我自行缝补,后又说要扣除运废才给退。我在评论区如实描述经历,结果发现好多消废者都有类似遭遇。这次购物让我明白,不能光看店家的宣全,要多查考真实评价。现在我已经学精了,下单前总会反复核对商品信息。"""
    print(func_macro_correct_long(text))

    examples = [
        "夫谷之雨,犹复云之亦从的起,因与疾风俱飘,参于天,集于的。",
        "机七学习是人工智能领遇最能体现智能的一个分知",
        '他们的吵翻很不错,再说他们做的咖喱鸡也好吃',
        "抗疫路上,除了提心吊胆也有难的得欢笑。",
        "我是练习时长两念半的鸽仁练习生蔡徐坤",
        "清晨,如纱一般地薄雾笼罩着世界。",
        "得府许我立庙于此,故请君移去尔。",
        "他法语说的很好,的语也不错",
        "遇到一位很棒的奴生跟我疗天",
        "五年级得数学,我考的很差。",
        "我们为这个目标努力不解",
        '今天兴情很好',
    ]

    gr.Interface(
        func_macro_correct_long,
        inputs='text',
        outputs='text',
        title="Chinese Spelling Correction Model Macropodus/macbert4mdcspell_v2",
        description="Paste or type Chinese text containing errors, then submit; the model returns the corrected text.",
        article="Link to <a href='https://github.com/yongzhuo/macro-correct' style='color:blue;' target='_blank'>Github REPO: macro-correct</a>",
        examples=examples
    ).launch()