princepride committed (verified)
Commit d22894a · 1 Parent(s): e8a87b8

Upload 2 files

Files changed (2)
  1. config.json +2 -2
  2. model.py +234 -0
config.json CHANGED
@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "yonyou-sg/nllb-200-distilled-1.3B",
+  "_name_or_path": "facebook/nllb-200-distilled-1.3B",
   "activation_dropout": 0.0,
   "activation_function": "relu",
   "architectures": [
@@ -28,7 +28,7 @@
   "pad_token_id": 1,
   "scale_embedding": true,
   "torch_dtype": "float32",
-  "transformers_version": "4.40.0",
+  "transformers_version": "4.35.2",
   "use_cache": true,
   "vocab_size": 256206
 }
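
The change points `_name_or_path` back at the upstream facebook/nllb-200-distilled-1.3B checkpoint and pins transformers at 4.35.2. A minimal loading sketch consistent with this config (the repo id is taken from the diff; the rest is standard transformers API):

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    # Assumes transformers==4.35.2, matching the "transformers_version" field above.
    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-1.3B")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-1.3B")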
model.py ADDED
@@ -0,0 +1,234 @@
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ import torch
+ from modules.file import ExcelFileWriter
+ import os
+
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+ parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(script_dir)))
+
+ class Model():
+     def __init__(self, modelname, selected_lora_model, selected_gpu):
+         def get_gpu_index(gpu_info, target_gpu_name):
+             """
+             Find the index of the target GPU in the GPU info list.
+             Args:
+                 gpu_info (list): list of GPU device names
+                 target_gpu_name (str): name of the target GPU
+             Returns:
+                 int: index of the target GPU, or -1 if not found
+             """
+             for i, name in enumerate(gpu_info):
+                 if target_gpu_name.lower() in name.lower():
+                     return i
+             return -1
+         if selected_gpu != "cpu":
+             gpu_count = torch.cuda.device_count()
+             gpu_info = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
+             selected_gpu_index = get_gpu_index(gpu_info, selected_gpu)
+             if selected_gpu_index == -1:
+                 # Fall back to the first GPU rather than building an invalid
+                 # "cuda:-1" device string when the requested name is not found.
+                 selected_gpu_index = 0
+             self.device_name = f"cuda:{selected_gpu_index}"
+         else:
+             self.device_name = "cpu"
+         print("device_name", self.device_name)
+         self.model = AutoModelForSeq2SeqLM.from_pretrained(modelname).to(self.device_name)
+         self.tokenizer = AutoTokenizer.from_pretrained(modelname)
+
+     def generate(self, inputs, original_language, target_languages, max_batch_size):
+         def language_mapping(original_language):
+             # Map human-readable language names to NLLB-200 (FLORES-200) codes.
+             d = {
+                 "Achinese (Arabic script)": "ace_Arab",
+                 "Achinese (Latin script)": "ace_Latn",
+                 "Mesopotamian Arabic": "acm_Arab",
+                 "Ta'izzi-Adeni Arabic": "acq_Arab",
+                 "Tunisian Arabic": "aeb_Arab",
+                 "Afrikaans": "afr_Latn",
+                 "South Levantine Arabic": "ajp_Arab",
+                 "Akan": "aka_Latn",
+                 "Amharic": "amh_Ethi",
+                 "North Levantine Arabic": "apc_Arab",
+                 "Standard Arabic": "arb_Arab",
+                 "Najdi Arabic": "ars_Arab",
+                 "Moroccan Arabic": "ary_Arab",
+                 "Egyptian Arabic": "arz_Arab",
+                 "Assamese": "asm_Beng",
+                 "Asturian": "ast_Latn",
+                 "Awadhi": "awa_Deva",
+                 "Central Aymara": "ayr_Latn",
+                 "South Azerbaijani": "azb_Arab",
+                 "North Azerbaijani": "azj_Latn",
+                 "Bashkir": "bak_Cyrl",
+                 "Bambara": "bam_Latn",
+                 "Balinese": "ban_Latn",
+                 "Belarusian": "bel_Cyrl",
+                 "Bemba": "bem_Latn",
+                 "Bengali": "ben_Beng",
+                 "Bhojpuri": "bho_Deva",
+                 "Banjar (Arabic script)": "bjn_Arab",
+                 "Banjar (Latin script)": "bjn_Latn",
+                 "Tibetan": "bod_Tibt",
+                 "Bosnian": "bos_Latn",
+                 "Buginese": "bug_Latn",
+                 "Bulgarian": "bul_Cyrl",
+                 "Catalan": "cat_Latn",
+                 "Cebuano": "ceb_Latn",
+                 "Czech": "ces_Latn",
+                 "Chokwe": "cjk_Latn",
+                 "Central Kurdish": "ckb_Arab",
+                 "Crimean Tatar": "crh_Latn",
+                 "Welsh": "cym_Latn",
+                 "Danish": "dan_Latn",
+                 "German": "deu_Latn",
+                 "Dinka": "dik_Latn",
+                 "Jula": "dyu_Latn",
+                 "Dzongkha": "dzo_Tibt",
+                 "Greek": "ell_Grek",
+                 "English": "eng_Latn",
+                 "Esperanto": "epo_Latn",
+                 "Estonian": "est_Latn",
+                 "Basque": "eus_Latn",
+                 "Ewe": "ewe_Latn",
+                 "Faroese": "fao_Latn",
+                 "Persian": "pes_Arab",
+                 "Fijian": "fij_Latn",
+                 "Finnish": "fin_Latn",
+                 "Fon": "fon_Latn",
+                 "French": "fra_Latn",
+                 "Friulian": "fur_Latn",
+                 "Nigerian Fulfulde": "fuv_Latn",
+                 "Scottish Gaelic": "gla_Latn",
+                 "Irish": "gle_Latn",
+                 "Galician": "glg_Latn",
+                 "Guarani": "grn_Latn",
+                 "Gujarati": "guj_Gujr",
+                 "Haitian Creole": "hat_Latn",
+                 "Hausa": "hau_Latn",
+                 "Hebrew": "heb_Hebr",
+                 "Hindi": "hin_Deva",
+                 "Chhattisgarhi": "hne_Deva",
+                 "Croatian": "hrv_Latn",
+                 "Hungarian": "hun_Latn",
+                 "Armenian": "hye_Armn",
+                 "Igbo": "ibo_Latn",
+                 "Iloko": "ilo_Latn",
+                 "Indonesian": "ind_Latn",
+                 "Icelandic": "isl_Latn",
+                 "Italian": "ita_Latn",
+                 "Javanese": "jav_Latn",
+                 "Japanese": "jpn_Jpan",
+                 "Kabyle": "kab_Latn",
+                 "Kachin": "kac_Latn",
+                 # "ar_AR" and "ne_NP" are mBART codes; NLLB-200 expects the
+                 # FLORES-200 codes "arb_Arab" and "npi_Deva" instead.
+                 "Arabic": "arb_Arab",
+                 "Chinese": "zho_Hans",
+                 "Spanish": "spa_Latn",
+                 "Dutch": "nld_Latn",
+                 "Kazakh": "kaz_Cyrl",
+                 "Korean": "kor_Hang",
+                 "Lithuanian": "lit_Latn",
+                 "Malayalam": "mal_Mlym",
+                 "Marathi": "mar_Deva",
+                 "Nepali": "npi_Deva",
+                 "Polish": "pol_Latn",
+                 "Portuguese": "por_Latn",
+                 "Russian": "rus_Cyrl",
+                 "Sinhala": "sin_Sinh",
+                 "Tamil": "tam_Taml",
+                 "Turkish": "tur_Latn",
+                 "Ukrainian": "ukr_Cyrl",
+                 "Urdu": "urd_Arab",
+                 "Vietnamese": "vie_Latn",
+                 "Thai": "tha_Thai"
+             }
+             return d[original_language]
+         def process_gpu_translate_result(temp_outputs):
+             # Flatten per-batch results and write them to a temporary Excel file.
+             outputs = []
+             for temp_output in temp_outputs:
+                 length = len(temp_output[0]["generated_translation"])
+                 for i in range(length):
+                     temp = []
+                     for trans in temp_output:
+                         temp.append({
+                             "target_language": trans["target_language"],
+                             "generated_translation": trans["generated_translation"][i],
+                         })
+                     outputs.append(temp)
+             excel_writer = ExcelFileWriter()
+             excel_writer.write_text(os.path.join(parent_dir, r"temp/empty.xlsx"), outputs, 'A', 1, len(outputs))
+         self.tokenizer.src_lang = language_mapping(original_language)
+         if self.device_name == "cpu":
+             # Tokenize input (truncation is required for max_length to take effect)
+             input_ids = self.tokenizer(inputs, return_tensors="pt", padding=True, truncation=True, max_length=128).to(self.device_name)
+             output = []
+             for target_language in target_languages:
+                 # Get the token id of the target language code
+                 target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
+                 # Generate translation
+                 generated_tokens = self.model.generate(
+                     **input_ids,
+                     forced_bos_token_id=target_lang_code,
+                     max_length=128
+                 )
+                 generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+                 # Append result to output
+                 output.append({
+                     "target_language": target_language,
+                     "generated_translation": generated_translation,
+                 })
+             # Regroup results so that outputs[i] holds all translations of inputs[i]
+             outputs = []
+             length = len(output[0]["generated_translation"])
+             for i in range(length):
+                 temp = []
+                 for trans in output:
+                     temp.append({
+                         "target_language": trans["target_language"],
+                         "generated_translation": trans["generated_translation"][i],
+                     })
+                 outputs.append(temp)
+             return outputs
+         else:
+             # max batch size = available GPU memory in bytes / 4 / (tensor size + trainable parameters)
+             # Ensure the batch size is within model limits:
+             print("length of inputs: ", len(inputs))
+             batch_size = min(len(inputs), int(max_batch_size))
+             batches = [inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)]
+             print("number of batches: ", len(batches))
+             temp_outputs = []
+             processed_num = 0
+             for batch in batches:
+                 # Tokenize input
+                 input_ids = self.tokenizer(batch, return_tensors="pt", padding=True).to(self.device_name)
+                 temp = []
+                 for target_language in target_languages:
+                     target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
+                     generated_tokens = self.model.generate(
+                         **input_ids,
+                         forced_bos_token_id=target_lang_code,
+                     )
+                     generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+                     # Append result to output
+                     temp.append({
+                         "target_language": target_language,
+                         "generated_translation": generated_translation,
+                     })
+                 del input_ids
+                 temp_outputs.append(temp)
+                 processed_num += len(batch)
+                 # Checkpoint intermediate results roughly every 1000 processed inputs.
+                 if processed_num // 1000 > (processed_num - len(batch)) // 1000:
+                     print("Already processed number: ", processed_num)
+                     process_gpu_translate_result(temp_outputs)
+             # Regroup results so that outputs[i] holds all translations of inputs[i]
+             outputs = []
+             for temp_output in temp_outputs:
+                 length = len(temp_output[0]["generated_translation"])
+                 for i in range(length):
+                     temp = []
+                     for trans in temp_output:
+                         temp.append({
+                             "target_language": trans["target_language"],
+                             "generated_translation": trans["generated_translation"][i],
+                         })
+                     outputs.append(temp)
+             return outputs
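
For reference, a hypothetical usage sketch of the class added above. The constructor and generate signatures come from model.py itself; the concrete model name, the GPU name substring, and the availability of modules.file.ExcelFileWriter are assumptions from this commit, not a documented API:

    # Assumed example values, not part of the committed code.
    model = Model(
        modelname="facebook/nllb-200-distilled-1.3B",
        selected_lora_model=None,   # accepted but unused by __init__ as written
        selected_gpu="cpu",         # or a GPU name substring, e.g. "RTX"
    )
    results = model.generate(
        inputs=["Hello, world!"],
        original_language="English",
        target_languages=["Chinese", "French"],
        max_batch_size=8,
    )
    # results[i] is a list of {"target_language", "generated_translation"}
    # dicts for the i-th input sentence.
    print(results)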