princepride commited on
Commit
eb95586
·
verified ·
1 Parent(s): 9258014

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +352 -234
model.py CHANGED
@@ -1,235 +1,353 @@
1
- from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, pipeline
2
- from abc import ABC, abstractmethod
3
- from typing import Type
4
- import torch
5
- import torch.nn.functional as F
6
- from modules.file import ExcelFileWriter
7
- import os
8
-
9
- script_dir = os.path.dirname(os.path.abspath(__file__))
10
- parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(script_dir)))
11
-
12
- class Model():
13
- def __init__(self, modelname, selected_lora_model, selected_gpu):
14
- def get_gpu_index(gpu_info, target_gpu_name):
15
- """
16
- 从 GPU 信息中获取目标 GPU 的索引
17
- Args:
18
- gpu_info (list): 包含 GPU 名称的列表
19
- target_gpu_name (str): 目标 GPU 的名称
20
-
21
- Returns:
22
- int: 目标 GPU 的索引,如果未找到则返回 -1
23
- """
24
- for i, name in enumerate(gpu_info):
25
- if target_gpu_name.lower() in name.lower():
26
- return i
27
- return -1
28
- if selected_gpu != "cpu":
29
- gpu_count = torch.cuda.device_count()
30
- gpu_info = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
31
- selected_gpu_index = get_gpu_index(gpu_info, selected_gpu)
32
- self.device_name = f"cuda:{selected_gpu_index}"
33
- else:
34
- self.device_name = "cpu"
35
- print("device_name", self.device_name)
36
- self.model = AutoModelForSeq2SeqLM.from_pretrained(modelname).to(self.device_name)
37
- self.tokenizer = AutoTokenizer.from_pretrained(modelname)
38
- # self.translator = pipeline('translation', model=self.original_model, tokenizer=self.tokenizer, src_lang=original_language, tgt_lang=target_language, device=device)
39
-
40
- def generate(self, inputs, original_language, target_languages, max_batch_size):
41
- def language_mapping(original_language):
42
- d = {
43
- "Achinese (Arabic script)": "ace_Arab",
44
- "Achinese (Latin script)": "ace_Latn",
45
- "Mesopotamian Arabic": "acm_Arab",
46
- "Ta'izzi-Adeni Arabic": "acq_Arab",
47
- "Tunisian Arabic": "aeb_Arab",
48
- "Afrikaans": "afr_Latn",
49
- "South Levantine Arabic": "ajp_Arab",
50
- "Akan": "aka_Latn",
51
- "Amharic": "amh_Ethi",
52
- "North Levantine Arabic": "apc_Arab",
53
- "Standard Arabic": "arb_Arab",
54
- "Najdi Arabic": "ars_Arab",
55
- "Moroccan Arabic": "ary_Arab",
56
- "Egyptian Arabic": "arz_Arab",
57
- "Assamese": "asm_Beng",
58
- "Asturian": "ast_Latn",
59
- "Awadhi": "awa_Deva",
60
- "Central Aymara": "ayr_Latn",
61
- "South Azerbaijani": "azb_Arab",
62
- "North Azerbaijani": "azj_Latn",
63
- "Bashkir": "bak_Cyrl",
64
- "Bambara": "bam_Latn",
65
- "Balinese": "ban_Latn",
66
- "Belarusian": "bel_Cyrl",
67
- "Bemba": "bem_Latn",
68
- "Bengali": "ben_Beng",
69
- "Bhojpuri": "bho_Deva",
70
- "Banjar (Arabic script)": "bjn_Arab",
71
- "Banjar (Latin script)": "bjn_Latn",
72
- "Tibetan": "bod_Tibt",
73
- "Bosnian": "bos_Latn",
74
- "Buginese": "bug_Latn",
75
- "Bulgarian": "bul_Cyrl",
76
- "Catalan": "cat_Latn",
77
- "Cebuano": "ceb_Latn",
78
- "Czech": "ces_Latn",
79
- "Chokwe": "cjk_Latn",
80
- "Central Kurdish": "ckb_Arab",
81
- "Crimean Tatar": "crh_Latn",
82
- "Welsh": "cym_Latn",
83
- "Danish": "dan_Latn",
84
- "German": "deu_Latn",
85
- "Dinka": "dik_Latn",
86
- "Jula": "dyu_Latn",
87
- "Dzongkha": "dzo_Tibt",
88
- "Greek": "ell_Grek",
89
- "English": "eng_Latn",
90
- "Esperanto": "epo_Latn",
91
- "Estonian": "est_Latn",
92
- "Basque": "eus_Latn",
93
- "Ewe": "ewe_Latn",
94
- "Faroese": "fao_Latn",
95
- "Persian": "pes_Arab",
96
- "Fijian": "fij_Latn",
97
- "Finnish": "fin_Latn",
98
- "Fon": "fon_Latn",
99
- "French": "fra_Latn",
100
- "Friulian": "fur_Latn",
101
- "Nigerian Fulfulde": "fuv_Latn",
102
- "Scottish Gaelic": "gla_Latn",
103
- "Irish": "gle_Latn",
104
- "Galician": "glg_Latn",
105
- "Guarani": "grn_Latn",
106
- "Gujarati": "guj_Gujr",
107
- "Haitian Creole": "hat_Latn",
108
- "Hausa": "hau_Latn",
109
- "Hebrew": "heb_Hebr",
110
- "Hindi": "hin_Deva",
111
- "Chhattisgarhi": "hne_Deva",
112
- "Croatian": "hrv_Latn",
113
- "Hungarian": "hun_Latn",
114
- "Armenian": "hye_Armn",
115
- "Igbo": "ibo_Latn",
116
- "Iloko": "ilo_Latn",
117
- "Indonesian": "ind_Latn",
118
- "Icelandic": "isl_Latn",
119
- "Italian": "ita_Latn",
120
- "Javanese": "jav_Latn",
121
- "Japanese": "jpn_Jpan",
122
- "Kabyle": "kab_Latn",
123
- "Kachin": "kac_Latn",
124
- "Arabic": "ar_AR",
125
- "Chinese": "zho_Hans",
126
- "Spanish": "spa_Latn",
127
- "Dutch": "nld_Latn",
128
- "Kazakh": "kaz_Cyrl",
129
- "Korean": "kor_Hang",
130
- "Lithuanian": "lit_Latn",
131
- "Malayalam": "mal_Mlym",
132
- "Marathi": "mar_Deva",
133
- "Nepali": "ne_NP",
134
- "Polish": "pol_Latn",
135
- "Portuguese": "por_Latn",
136
- "Russian": "rus_Cyrl",
137
- "Sinhala": "sin_Sinh",
138
- "Tamil": "tam_Taml",
139
- "Turkish": "tur_Latn",
140
- "Ukrainian": "ukr_Cyrl",
141
- "Urdu": "urd_Arab",
142
- "Vietnamese": "vie_Latn",
143
- "Thai":"tha_Thai"
144
- }
145
- return d[original_language]
146
- def process_gpu_translate_result(temp_outputs):
147
- outputs = []
148
- for temp_output in temp_outputs:
149
- length = len(temp_output[0]["generated_translation"])
150
- for i in range(length):
151
- temp = []
152
- for trans in temp_output:
153
- temp.append({
154
- "target_language": trans["target_language"],
155
- "generated_translation": trans['generated_translation'][i],
156
- })
157
- outputs.append(temp)
158
- excel_writer = ExcelFileWriter()
159
- excel_writer.write_text(os.path.join(parent_dir,r"temp/empty.xlsx"), outputs, 'A', 1, len(outputs))
160
- self.tokenizer.src_lang = language_mapping(original_language)
161
- if self.device_name == "cpu":
162
- # Tokenize input
163
- input_ids = self.tokenizer(inputs, return_tensors="pt", padding=True, max_length=128).to(self.device_name)
164
- output = []
165
- for target_language in target_languages:
166
- # Get language code for the target language
167
- target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
168
- # Generate translation
169
- generated_tokens = self.model.generate(
170
- **input_ids,
171
- forced_bos_token_id=target_lang_code,
172
- max_length=128
173
- )
174
- generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
175
- # Append result to output
176
- output.append({
177
- "target_language": target_language,
178
- "generated_translation": generated_translation,
179
- })
180
- outputs = []
181
- length = len(output[0]["generated_translation"])
182
- for i in range(length):
183
- temp = []
184
- for trans in output:
185
- temp.append({
186
- "target_language": trans["target_language"],
187
- "generated_translation": trans['generated_translation'][i],
188
- })
189
- outputs.append(temp)
190
- return outputs
191
- else:
192
- # 最大批量大小 = 可用 GPU 内存字节数 / 4 / (张量大小 + 可训练参数)
193
- # max_batch_size = 10
194
- # Ensure batch size is within model limits:
195
- print("length of inputs: ",len(inputs))
196
- batch_size = min(len(inputs), int(max_batch_size))
197
- batches = [inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)]
198
- print("length of batches size: ", len(batches))
199
- temp_outputs = []
200
- processed_num = 0
201
- for index, batch in enumerate(batches):
202
- # Tokenize input
203
- input_ids = self.tokenizer(batch, return_tensors="pt", padding=True).to(self.device_name)
204
- temp = []
205
- for target_language in target_languages:
206
- target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
207
- generated_tokens = self.model.generate(
208
- **input_ids,
209
- forced_bos_token_id=target_lang_code,
210
- )
211
- generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
212
- # Append result to output
213
- temp.append({
214
- "target_language": target_language,
215
- "generated_translation": generated_translation,
216
- })
217
- input_ids.to('cpu')
218
- del input_ids
219
- temp_outputs.append(temp)
220
- processed_num += len(batch)
221
- if (index + 1) * max_batch_size // 1000 - index * max_batch_size // 1000 == 1:
222
- print("Already processed number: ", len(temp_outputs))
223
- process_gpu_translate_result(temp_outputs)
224
- outputs = []
225
- for temp_output in temp_outputs:
226
- length = len(temp_output[0]["generated_translation"])
227
- for i in range(length):
228
- temp = []
229
- for trans in temp_output:
230
- temp.append({
231
- "target_language": trans["target_language"],
232
- "generated_translation": trans['generated_translation'][i],
233
- })
234
- outputs.append(temp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  return outputs
 
1
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
2
+ import torch
3
+ from modules.file import ExcelFileWriter
4
+ import os
5
+
6
+ from abc import ABC, abstractmethod
7
+ from typing import List
8
+ import re
9
+
10
class FilterPipeline():
    """Runs an ordered list of reversible Filter objects over a batch.

    ``batch_encoder`` applies each filter's ``encoder`` front-to-back before
    translation; ``batch_decoder`` applies each filter's ``decoder``
    back-to-front afterwards, mirroring the encode order so transformations
    are undone in reverse.
    """

    def __init__(self, filter_list):
        """Create a pipeline from *filter_list* (filters applied in order).

        The list is copied defensively so later mutation of the caller's
        list does not silently change the pipeline; use ``append`` to add
        filters. The annotation is a string forward reference because
        ``Filter`` is defined later in this module.
        """
        self._filter_list: List["Filter"] = list(filter_list)

    def append(self, filter):
        """Add one filter to the end of the pipeline."""
        self._filter_list.append(filter)

    def batch_encoder(self, inputs):
        """Apply every filter's encoder to *inputs*, first filter first."""
        # Loop variable renamed from `filter` to avoid shadowing the builtin.
        for flt in self._filter_list:
            inputs = flt.encoder(inputs)
        return inputs

    def batch_decoder(self, inputs):
        """Apply every filter's decoder to *inputs*, last filter first."""
        for flt in reversed(self._filter_list):
            inputs = flt.decoder(inputs)
        return inputs
26
+
27
class Filter(ABC):
    """Abstract base class for the reversible filters used by FilterPipeline.

    ``encoder`` transforms a batch of strings before translation; the
    matching ``decoder`` undoes that transformation on the translated batch.
    """
    def __init__(self):
        # Human-readable filter name (overridden by subclasses).
        self.name = 'filter'
        # Per-batch state written by encoder() and read back by decoder().
        self.code = []
    @abstractmethod
    def encoder(self, inputs):
        """Transform *inputs* (list of str) before translation."""
        pass

    @abstractmethod
    def decoder(self, inputs):
        """Undo the encoder transformation on translated *inputs* (list of str)."""
        pass
38
+
39
class SpecialTokenFilter(Filter):
    """Drops strings made up entirely of special tokens before translation
    and re-inserts them, untouched, at their original positions afterwards."""

    def __init__(self):
        self.name = 'special token filter'
        self.code = []
        self.special_tokens = ['!', '!']

    def encoder(self, inputs):
        """Return *inputs* minus pure special-token strings; record them."""
        self.code = []
        kept = []
        for idx, text in enumerate(inputs):
            # NOTE: an empty string also satisfies all() and is dropped here.
            if all(ch in self.special_tokens for ch in text):
                self.code.append([idx, text])  # remember (position, text)
            else:
                kept.append(text)
        return kept

    def decoder(self, inputs):
        """Re-insert the recorded strings at their original indices."""
        restored = list(inputs)
        # self.code holds ascending indices, so sequential inserts land
        # each removed string back at its original position.
        for position, text in self.code:
            restored.insert(position, text)
        return restored
60
+
61
class SperSignFilter(Filter):
    """Masks the literal token 's%' as '*' so the translator leaves it
    alone, then unmasks it in the translated output."""

    def __init__(self):
        self.name = 's persign filter'
        self.code = []

    def encoder(self, inputs):
        """Replace every 's%' with '*' and remember which rows were masked."""
        self.code = []
        masked = []
        for idx, text in enumerate(inputs):
            if 's%' in text:
                self.code.append(idx)  # row needs unmasking in decoder
                masked.append(text.replace('s%', '*'))
            else:
                masked.append(text)
        return masked

    def decoder(self, inputs):
        """Turn '*' back into 's%' in the rows recorded by encoder()."""
        restored = list(inputs)
        for idx in self.code:
            # NOTE(review): this also rewrites any '*' that was already in
            # the translated row — assumes '*' only comes from the mask.
            restored[idx] = restored[idx].replace('*', 's%')
        return restored
83
+
84
class SimilarFilter(Filter):
    """Collapses runs of consecutive strings that differ only in digits.

    Only the first string of each run is passed on for translation; the
    decoder drops that representative's translation and restores the whole
    run's ORIGINAL strings untranslated (the run is assumed to be
    boilerplate like numbered labels — TODO confirm this is the intended
    contract with callers).
    """

    def __init__(self):
        self.name = 'similar filter'
        self.code = []

    def is_similar(self, str1, str2):
        """Two strings are "similar" when they are equal after removing digits."""
        pattern = re.compile(r'\d+')
        return pattern.sub('', str1) == pattern.sub('', str2)

    def encoder(self, inputs):
        """Return *inputs* with each similar-run collapsed to its first element."""
        encoded_inputs = []
        self.code = []
        i = 0
        while i < len(inputs):
            encoded_inputs.append(inputs[i])
            similar_strs = [inputs[i]]
            j = i + 1
            while j < len(inputs) and self.is_similar(inputs[i], inputs[j]):
                similar_strs.append(inputs[j])
                j += 1
            if len(similar_strs) > 1:
                # BUG FIX: record the representative's position in the
                # ENCODED list (what decoder() receives), not its position
                # in the original input list — the old code mixed the two
                # index spaces and produced misaligned, wrong-length output.
                self.code.append((len(encoded_inputs) - 1, similar_strs))
            i = j
        return encoded_inputs

    def decoder(self, inputs):
        """Expand each collapsed run back to its original strings."""
        decoded_inputs = []
        index = 0
        for enc_idx, similar_strs in self.code:
            # Copy translations up to the run's representative...
            decoded_inputs.extend(inputs[index:enc_idx])
            # ...then substitute the run's original strings for it.
            decoded_inputs.extend(similar_strs)
            index = enc_idx + 1  # skip the representative's translation
        decoded_inputs.extend(inputs[index:])
        return decoded_inputs
119
+
120
# Absolute directory containing this file; parent_dir points three levels
# up, which is where generate() writes its temp/empty.xlsx checkpoint file.
script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(script_dir)))
122
+
123
class Model():
    """Wraps a Hugging Face seq2seq translation model that uses NLLB-style
    language codes, plus a FilterPipeline that protects special strings
    (special-token-only rows, 's%' markers, digit-only variants) across
    translation.
    """
    def __init__(self, modelname, selected_lora_model, selected_gpu):
        """Load model and tokenizer onto the requested device.

        Args:
            modelname: Hugging Face model id or local path.
            selected_lora_model: unused in this code path — presumably kept
                for interface compatibility with callers (TODO confirm).
            selected_gpu: GPU display name to match against available CUDA
                devices, or "cpu".
        """
        def get_gpu_index(gpu_info, target_gpu_name):
            """
            Find the index of the target GPU in the GPU info list.
            Args:
                gpu_info (list): list of GPU device names
                target_gpu_name (str): name of the GPU to look for

            Returns:
                int: index of the target GPU, or -1 if not found
            """
            # Case-insensitive substring match against each device name.
            for i, name in enumerate(gpu_info):
                if target_gpu_name.lower() in name.lower():
                    return i
            return -1
        if selected_gpu != "cpu":
            gpu_count = torch.cuda.device_count()
            gpu_info = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
            # NOTE(review): if the name is not found this yields "cuda:-1",
            # which fails later at .to() — confirm callers always pass a
            # valid device name.
            selected_gpu_index = get_gpu_index(gpu_info, selected_gpu)
            self.device_name = f"cuda:{selected_gpu_index}"
        else:
            self.device_name = "cpu"
        print("device_name", self.device_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(modelname).to(self.device_name)
        self.tokenizer = AutoTokenizer.from_pretrained(modelname)
        # self.translator = pipeline('translation', model=self.original_model, tokenizer=self.tokenizer, src_lang=original_language, tgt_lang=target_language, device=device)

    def generate(self, inputs, original_language, target_languages, max_batch_size):
        """Translate *inputs* from *original_language* into every language
        in *target_languages*.

        Args:
            inputs: list of source strings.
            original_language: display name of the source language (must be
                a key of the mapping below, else KeyError).
            target_languages: list of target-language display names.
            max_batch_size: upper bound on rows tokenized per GPU batch.

        Returns:
            One entry per input row: a list of
            {"target_language": ..., "generated_translation": ...} dicts,
            one dict per target language.
        """
        # Filters are applied in this order and undone in reverse.
        filter_list = [SpecialTokenFilter(), SperSignFilter(), SimilarFilter()]
        filter_pipeline = FilterPipeline(filter_list)
        def language_mapping(original_language):
            # Display name -> NLLB-style language code.
            # Raises KeyError for unsupported languages.
            # NOTE(review): "Arabic" -> "ar_AR" and "Nepali" -> "ne_NP" are
            # mBART-style codes, unlike the rest — confirm they are valid
            # for the loaded tokenizer.
            d = {
                "Achinese (Arabic script)": "ace_Arab",
                "Achinese (Latin script)": "ace_Latn",
                "Mesopotamian Arabic": "acm_Arab",
                "Ta'izzi-Adeni Arabic": "acq_Arab",
                "Tunisian Arabic": "aeb_Arab",
                "Afrikaans": "afr_Latn",
                "South Levantine Arabic": "ajp_Arab",
                "Akan": "aka_Latn",
                "Amharic": "amh_Ethi",
                "North Levantine Arabic": "apc_Arab",
                "Standard Arabic": "arb_Arab",
                "Najdi Arabic": "ars_Arab",
                "Moroccan Arabic": "ary_Arab",
                "Egyptian Arabic": "arz_Arab",
                "Assamese": "asm_Beng",
                "Asturian": "ast_Latn",
                "Awadhi": "awa_Deva",
                "Central Aymara": "ayr_Latn",
                "South Azerbaijani": "azb_Arab",
                "North Azerbaijani": "azj_Latn",
                "Bashkir": "bak_Cyrl",
                "Bambara": "bam_Latn",
                "Balinese": "ban_Latn",
                "Belarusian": "bel_Cyrl",
                "Bemba": "bem_Latn",
                "Bengali": "ben_Beng",
                "Bhojpuri": "bho_Deva",
                "Banjar (Arabic script)": "bjn_Arab",
                "Banjar (Latin script)": "bjn_Latn",
                "Tibetan": "bod_Tibt",
                "Bosnian": "bos_Latn",
                "Buginese": "bug_Latn",
                "Bulgarian": "bul_Cyrl",
                "Catalan": "cat_Latn",
                "Cebuano": "ceb_Latn",
                "Czech": "ces_Latn",
                "Chokwe": "cjk_Latn",
                "Central Kurdish": "ckb_Arab",
                "Crimean Tatar": "crh_Latn",
                "Welsh": "cym_Latn",
                "Danish": "dan_Latn",
                "German": "deu_Latn",
                "Dinka": "dik_Latn",
                "Jula": "dyu_Latn",
                "Dzongkha": "dzo_Tibt",
                "Greek": "ell_Grek",
                "English": "eng_Latn",
                "Esperanto": "epo_Latn",
                "Estonian": "est_Latn",
                "Basque": "eus_Latn",
                "Ewe": "ewe_Latn",
                "Faroese": "fao_Latn",
                "Persian": "pes_Arab",
                "Fijian": "fij_Latn",
                "Finnish": "fin_Latn",
                "Fon": "fon_Latn",
                "French": "fra_Latn",
                "Friulian": "fur_Latn",
                "Nigerian Fulfulde": "fuv_Latn",
                "Scottish Gaelic": "gla_Latn",
                "Irish": "gle_Latn",
                "Galician": "glg_Latn",
                "Guarani": "grn_Latn",
                "Gujarati": "guj_Gujr",
                "Haitian Creole": "hat_Latn",
                "Hausa": "hau_Latn",
                "Hebrew": "heb_Hebr",
                "Hindi": "hin_Deva",
                "Chhattisgarhi": "hne_Deva",
                "Croatian": "hrv_Latn",
                "Hungarian": "hun_Latn",
                "Armenian": "hye_Armn",
                "Igbo": "ibo_Latn",
                "Iloko": "ilo_Latn",
                "Indonesian": "ind_Latn",
                "Icelandic": "isl_Latn",
                "Italian": "ita_Latn",
                "Javanese": "jav_Latn",
                "Japanese": "jpn_Jpan",
                "Kabyle": "kab_Latn",
                "Kachin": "kac_Latn",
                "Arabic": "ar_AR",
                "Chinese": "zho_Hans",
                "Spanish": "spa_Latn",
                "Dutch": "nld_Latn",
                "Kazakh": "kaz_Cyrl",
                "Korean": "kor_Hang",
                "Lithuanian": "lit_Latn",
                "Malayalam": "mal_Mlym",
                "Marathi": "mar_Deva",
                "Nepali": "ne_NP",
                "Polish": "pol_Latn",
                "Portuguese": "por_Latn",
                "Russian": "rus_Cyrl",
                "Sinhala": "sin_Sinh",
                "Tamil": "tam_Taml",
                "Turkish": "tur_Latn",
                "Ukrainian": "ukr_Cyrl",
                "Urdu": "urd_Arab",
                "Vietnamese": "vie_Latn",
                "Thai":"tha_Thai"
            }
            return d[original_language]
        def process_gpu_translate_result(temp_outputs):
            # Checkpoint helper: flatten the per-batch/per-language results
            # into one row per input, then write them to the Excel file
            # under parent_dir so partial progress survives a crash.
            outputs = []
            for temp_output in temp_outputs:
                length = len(temp_output[0]["generated_translation"])
                for i in range(length):
                    temp = []
                    for trans in temp_output:
                        temp.append({
                            "target_language": trans["target_language"],
                            "generated_translation": trans['generated_translation'][i],
                        })
                    outputs.append(temp)
            excel_writer = ExcelFileWriter()
            excel_writer.write_text(os.path.join(parent_dir,r"temp/empty.xlsx"), outputs, 'A', 1, len(outputs))
        self.tokenizer.src_lang = language_mapping(original_language)
        if self.device_name == "cpu":
            # CPU path: single batch, no filter pipeline, no checkpointing.
            # Tokenize input
            # NOTE(review): max_length without truncation=True does not
            # truncate — confirm inputs are known to fit.
            input_ids = self.tokenizer(inputs, return_tensors="pt", padding=True, max_length=128).to(self.device_name)
            output = []
            for target_language in target_languages:
                # Get language code for the target language
                target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
                # Generate translation, forcing the target-language BOS token.
                generated_tokens = self.model.generate(
                    **input_ids,
                    forced_bos_token_id=target_lang_code,
                    max_length=128
                )
                generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
                # Append result to output
                output.append({
                    "target_language": target_language,
                    "generated_translation": generated_translation,
                })
            # Transpose from per-language lists to one entry per input row.
            outputs = []
            length = len(output[0]["generated_translation"])
            for i in range(length):
                temp = []
                for trans in output:
                    temp.append({
                        "target_language": trans["target_language"],
                        "generated_translation": trans['generated_translation'][i],
                    })
                outputs.append(temp)
            return outputs
        else:
            # GPU path: split into batches, run the filter pipeline around
            # each batch, and checkpoint periodically.
            # max batch size = available GPU memory bytes / 4 / (tensor size + trainable parameters)
            # max_batch_size = 10
            # Ensure batch size is within model limits:
            print("length of inputs: ",len(inputs))
            batch_size = min(len(inputs), int(max_batch_size))
            batches = [inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)]
            print("length of batches size: ", len(batches))
            temp_outputs = []
            processed_num = 0
            for index, batch in enumerate(batches):
                # Tokenize input
                # Encode (filter) the batch; the same pipeline instance
                # decodes it below, per target language.
                batch = filter_pipeline.batch_encoder(batch)
                print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
                print(batch)
                input_ids = self.tokenizer(batch, return_tensors="pt", padding=True).to(self.device_name)
                temp = []
                for target_language in target_languages:
                    target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
                    generated_tokens = self.model.generate(
                        **input_ids,
                        forced_bos_token_id=target_lang_code,
                    )
                    generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
                    print(generated_translation)
                    # Undo the filters so dropped/masked rows reappear.
                    generated_translation = filter_pipeline.batch_decoder(generated_translation)
                    # Append result to output
                    temp.append({
                        "target_language": target_language,
                        "generated_translation": generated_translation,
                    })
                # Free the batch's GPU tensors before the next iteration.
                input_ids.to('cpu')
                del input_ids
                temp_outputs.append(temp)
                # NOTE(review): batch was re-bound to the FILTERED list
                # above, so this counts filtered rows, not original rows.
                processed_num += len(batch)
                # Fires when cumulative index*max_batch_size crosses a
                # multiple of 1000: roughly every 1000 scheduled inputs,
                # checkpoint the results gathered so far.
                if (index + 1) * max_batch_size // 1000 - index * max_batch_size // 1000 == 1:
                    print("Already processed number: ", len(temp_outputs))
                    process_gpu_translate_result(temp_outputs)
            # Final flatten: one entry per input row, each holding all
            # target-language translations for that row.
            outputs = []
            for temp_output in temp_outputs:
                length = len(temp_output[0]["generated_translation"])
                for i in range(length):
                    temp = []
                    for trans in temp_output:
                        temp.append({
                            "target_language": trans["target_language"],
                            "generated_translation": trans['generated_translation'][i],
                        })
                    outputs.append(temp)
            return outputs