Update model.py
Browse files
model.py
CHANGED
@@ -69,7 +69,7 @@ class SperSignFilter(Filter):
|
|
69 |
for i, input_str in enumerate(inputs):
|
70 |
if '%s' in input_str:
|
71 |
encoded_str = input_str.replace('%s', '*')
|
72 |
-
self.code.append(i) # 将包含 's
|
73 |
else:
|
74 |
encoded_str = input_str
|
75 |
encoded_inputs.append(encoded_str)
|
@@ -80,6 +80,29 @@ class SperSignFilter(Filter):
|
|
80 |
for i in self.code:
|
81 |
decoded_inputs[i] = decoded_inputs[i].replace('*', '%s') # 使用 self.code 中的索引还原原始字符串
|
82 |
return decoded_inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
class ChevronsFilter(Filter):
|
85 |
def __init__(self):
|
@@ -238,7 +261,7 @@ class Model():
|
|
238 |
# self.translator = pipeline('translation', model=self.original_model, tokenizer=self.tokenizer, src_lang=original_language, tgt_lang=target_language, device=device)
|
239 |
|
240 |
def generate(self, inputs, original_language, target_languages, max_batch_size):
|
241 |
-
filter_list = [SpecialTokenFilter(), SperSignFilter(), ChevronsFilter(), SimilarFilter(), ChineseFilter()]
|
242 |
filter_pipeline = FilterPipeline(filter_list)
|
243 |
def language_mapping(original_language):
|
244 |
d = {
|
|
|
69 |
for i, input_str in enumerate(inputs):
|
70 |
if '%s' in input_str:
|
71 |
encoded_str = input_str.replace('%s', '*')
|
72 |
+
self.code.append(i) # 将包含 '%s' 的字符串的索引存储到 self.code 中
|
73 |
else:
|
74 |
encoded_str = input_str
|
75 |
encoded_inputs.append(encoded_str)
|
|
|
80 |
for i in self.code:
|
81 |
decoded_inputs[i] = decoded_inputs[i].replace('*', '%s') # 使用 self.code 中的索引还原原始字符串
|
82 |
return decoded_inputs
|
83 |
+
|
84 |
+
class ParenSParenFilter(Filter):
    """Swap the literal ``(s)`` marker for a ``$`` placeholder and back.

    ``encoder`` rewrites each input string that contains ``(s)`` and records
    the positions it touched in ``self.code``; ``decoder`` reverses the swap
    at exactly those positions.  Presumably this keeps the marker intact
    through whatever step runs between the two calls -- confirm against the
    pipeline that drives the filters.
    """

    def __init__(self):
        self.name = 'Paren s paren filter'
        # Indices (into the most recently encoded list) that were rewritten.
        self.code = []

    def encoder(self, inputs):
        """Return a new list where ``(s)`` is replaced by ``$``.

        Args:
            inputs: Sequence of strings to encode.

        Returns:
            A list of the same length as ``inputs``.  Strings that were
            rewritten have their index stored in ``self.code``.
        """
        encoded_inputs = []
        self.code = []  # Forget indices left over from a previous call.
        for i, input_str in enumerate(inputs):
            # A string that already holds a literal '$' cannot be round-tripped:
            # the decoder would turn that pre-existing '$' into '(s)' as well.
            # Leave such strings untouched instead of corrupting them.
            if '(s)' in input_str and '$' not in input_str:
                encoded_inputs.append(input_str.replace('(s)', '$'))
                self.code.append(i)  # Remember which entries need decoding.
            else:
                encoded_inputs.append(input_str)
        return encoded_inputs

    def decoder(self, inputs):
        """Restore ``(s)`` in the entries that ``encoder`` rewrote.

        Args:
            inputs: List of strings, positionally aligned with the list that
                was last passed to ``encoder``.

        Returns:
            A shallow copy of ``inputs`` with ``$`` replaced by ``(s)`` at
            every index recorded in ``self.code``; other entries are unchanged.

        Raises:
            IndexError: If ``inputs`` is shorter than a recorded index.
        """
        decoded_inputs = inputs.copy()
        for i in self.code:
            decoded_inputs[i] = decoded_inputs[i].replace('$', '(s)')
        return decoded_inputs
|
106 |
|
107 |
class ChevronsFilter(Filter):
|
108 |
def __init__(self):
|
|
|
261 |
# self.translator = pipeline('translation', model=self.original_model, tokenizer=self.tokenizer, src_lang=original_language, tgt_lang=target_language, device=device)
|
262 |
|
263 |
def generate(self, inputs, original_language, target_languages, max_batch_size):
|
264 |
+
filter_list = [SpecialTokenFilter(), SperSignFilter(), ParenSParenFilter(), ChevronsFilter(), SimilarFilter(), ChineseFilter()]
|
265 |
filter_pipeline = FilterPipeline(filter_list)
|
266 |
def language_mapping(original_language):
|
267 |
d = {
|