indiejoseph commited on
Commit
0e5205e
1 Parent(s): a99de3c

Create bart_pipeline.py

Browse files
Files changed (1) hide show
  1. bart_pipeline.py +99 -0
bart_pipeline.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import TranslationPipeline
2
+ from transformers.pipelines.text2text_generation import ReturnType
3
+ from transformers import BartForConditionalGeneration, BertTokenizer
4
+ import logging
5
+ import re
6
+
7
+
8
+ def fix_chinese_text_generation_space(text):
9
+ output_text = text
10
+ output_text = re.sub(
11
+ r'([\u3401-\u9FFF+——!,。?、~@#¥%…&*():;《)《》“”()»〔〕\-!$^*()_+|~=`{}\[\]:";\'<>?,.\/\\])\s([^0-9a-zA-Z])', r'\1\2', output_text)
12
+ output_text = re.sub(
13
+ r'([^0-9a-zA-Z])\s([\u3401-\u9FFF+——!,。?、~@#¥%…&*():;《)《》“”()»〔〕\-!$^*()_+|~=`{}\[\]:";\'<>?,.\/\\])', r'\1\2', output_text)
14
+ output_text = re.sub(
15
+ r'([\u3401-\u9FFF+——!,。?、~@#¥%…&*():;《)《》“”()»〔〕\-!$^*()_+|~=`{}\[\]:";\'<>?,.\/\\])\s([a-zA-Z0-9])', r'\1\2', output_text)
16
+ output_text = re.sub(
17
+ r'([a-zA-Z0-9])\s([\u3401-\u9FFF+——!,。?、~@#¥%…&*():;《)《》“”()»〔〕\-!$^*()_+|~=`{}\[\]:";\'<>?,.\/\\])', r'\1\2', output_text)
18
+ output_text = re.sub(r'$\s([0-9])', r'$\1', output_text)
19
+ output_text = re.sub(',', ',', output_text)
20
+ output_text = re.sub(r'([0-9]),([0-9])', r'\1,\2',
21
+ output_text) # fix comma in numbers
22
+ # fix multiple commas
23
+ output_text = re.sub(r'\s?[,]+\s?', ',', output_text)
24
+ output_text = re.sub(r'\s?[、]+\s?', '、', output_text)
25
+ # fix period
26
+ output_text = re.sub(r'\s?[。]+\s?', '。', output_text)
27
+ # fix ...
28
+ output_text = re.sub(r'\s?\.{3,}\s?', '...', output_text)
29
+ # fix exclamation mark
30
+ output_text = re.sub(r'\s?[!!]+\s?', '!', output_text)
31
+ # fix question mark
32
+ output_text = re.sub(r'\s?[??]+\s?', '?', output_text)
33
+ # fix colon
34
+ output_text = re.sub(r'\s?[::]+\s?', ':', output_text)
35
+ # fix quotation mark
36
+ output_text = re.sub(r'\s?(["“”\']+)\s?', r'\1', output_text)
37
+ # fix semicolon
38
+ output_text = re.sub(r'\s?[;;]+\s?', ';', output_text)
39
+ # fix dots
40
+ output_text = re.sub(r'\s?([~●.…]+)\s?', r'\1', output_text)
41
+ output_text = re.sub(r'\s?\[…\]\s?', '', output_text)
42
+ output_text = re.sub(r'\s?\[\.\.\.\]\s?', '', output_text)
43
+ output_text = re.sub(r'\s?\.{3,}\s?', '...', output_text)
44
+ # fix slash
45
+ output_text = re.sub(r'\s?[//]+\s?', '/', output_text)
46
+ # fix dollar sign
47
+ output_text = re.sub(r'\s?[$$]+\s?', '$', output_text)
48
+ # fix @
49
+ output_text = re.sub(r'\s?([@@]+)\s?', '@', output_text)
50
+ # fix baskets
51
+ output_text = re.sub(
52
+ r'\s?([\[\(<〖【「『()』」】〗>\)\]]+)\s?', r'\1', output_text)
53
+
54
+ return output_text
55
+
56
+
57
+ class BartPipeline(TranslationPipeline):
58
+ def __init__(self,
59
+ model_name_or_path: str = "indiejoseph/bart-base-cantonese",
60
+ device=None,
61
+ max_length=512,
62
+ src_lang=None,
63
+ tgt_lang=None):
64
+ self.model_name_or_path = model_name_or_path
65
+ self.tokenizer = self._load_tokenizer()
66
+ self.model = self._load_model()
67
+ self.model.eval()
68
+ super().__init__(self.model, self.tokenizer, device=device,
69
+ max_length=max_length, src_lang=src_lang, tgt_lang=tgt_lang)
70
+
71
+ def _load_tokenizer(self):
72
+ return BertTokenizer.from_pretrained(self.model_name_or_path)
73
+
74
+ def _load_model(self):
75
+ return BartForConditionalGeneration.from_pretrained(self.model_name_or_path)
76
+
77
+ def postprocess(
78
+ self,
79
+ model_outputs,
80
+ return_type=ReturnType.TEXT,
81
+ clean_up_tokenization_spaces=True,
82
+ ):
83
+ records = super().postprocess(
84
+ model_outputs,
85
+ return_type=return_type,
86
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
87
+ )
88
+ for rec in records:
89
+ translation_text = fix_chinese_text_generation_space(
90
+ rec["translation_text"].strip())
91
+
92
+ rec["translation_text"] = translation_text
93
+ return records
94
+
95
+
96
+ if __name__ == '__main__':
97
+ pipe = BartPipeline(device=0)
98
+
99
+ print(pipe('哈哈,我正在努力研究緊個問題。不過,邊個知呢,可能哪一日我會諗到一個好主意去實現到佢。', max_length=100))