Upload 2 files
- config.json +2 -2
- model.py +234 -0
config.json CHANGED
@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "
+  "_name_or_path": "facebook/nllb-200-distilled-1.3B",
   "activation_dropout": 0.0,
   "activation_function": "relu",
   "architectures": [
@@ -28,7 +28,7 @@
   "pad_token_id": 1,
   "scale_embedding": true,
   "torch_dtype": "float32",
-  "transformers_version": "4.
+  "transformers_version": "4.35.2",
   "use_cache": true,
   "vocab_size": 256206
 }
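For reference, a minimal sketch of loading the checkpoint this config describes. The model name comes from the `_name_or_path` field above and the pinned transformers version is 4.35.2; the generation call itself is an illustration, not part of this commit:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Model name taken from "_name_or_path" in the config above.
name = "facebook/nllb-200-distilled-1.3B"
tokenizer = AutoTokenizer.from_pretrained(name, src_lang="eng_Latn")
model = AutoModelForSeq2SeqLM.from_pretrained(name)  # float32 weights per "torch_dtype"

batch = tokenizer(["Hello, world!"], return_tensors="pt")
tokens = model.generate(
    **batch,
    # NLLB expects the target-language token forced as the first generated token.
    forced_bos_token_id=tokenizer.convert_tokens_to_ids("fra_Latn"),
    max_length=128,
)
print(tokenizer.batch_decode(tokens, skip_special_tokens=True))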
model.py ADDED
@@ -0,0 +1,234 @@
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import os

from modules.file import ExcelFileWriter

script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(script_dir)))


class Model:
    def __init__(self, modelname, selected_lora_model, selected_gpu):
        def get_gpu_index(gpu_info, target_gpu_name):
            """
            Find the index of the target GPU in the device list.

            Args:
                gpu_info (list): list of GPU device names
                target_gpu_name (str): name of the GPU to look for

            Returns:
                int: index of the target GPU, or -1 if it is not found
            """
            for i, name in enumerate(gpu_info):
                if target_gpu_name.lower() in name.lower():
                    return i
            return -1

        if selected_gpu != "cpu":
            gpu_count = torch.cuda.device_count()
            gpu_info = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
            selected_gpu_index = get_gpu_index(gpu_info, selected_gpu)
            self.device_name = f"cuda:{selected_gpu_index}"
        else:
            self.device_name = "cpu"
        print("device_name", self.device_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(modelname).to(self.device_name)
        self.tokenizer = AutoTokenizer.from_pretrained(modelname)
        # self.translator = pipeline('translation', model=self.original_model, tokenizer=self.tokenizer, src_lang=original_language, tgt_lang=target_language, device=device)

    def generate(self, inputs, original_language, target_languages, max_batch_size):
        def language_mapping(original_language):
            # Map display names to NLLB (FLORES-200) language codes.
            d = {
                "Achinese (Arabic script)": "ace_Arab",
                "Achinese (Latin script)": "ace_Latn",
                "Mesopotamian Arabic": "acm_Arab",
                "Ta'izzi-Adeni Arabic": "acq_Arab",
                "Tunisian Arabic": "aeb_Arab",
                "Afrikaans": "afr_Latn",
                "South Levantine Arabic": "ajp_Arab",
                "Akan": "aka_Latn",
                "Amharic": "amh_Ethi",
                "North Levantine Arabic": "apc_Arab",
                "Standard Arabic": "arb_Arab",
                "Najdi Arabic": "ars_Arab",
                "Moroccan Arabic": "ary_Arab",
                "Egyptian Arabic": "arz_Arab",
                "Assamese": "asm_Beng",
                "Asturian": "ast_Latn",
                "Awadhi": "awa_Deva",
                "Central Aymara": "ayr_Latn",
                "South Azerbaijani": "azb_Arab",
                "North Azerbaijani": "azj_Latn",
                "Bashkir": "bak_Cyrl",
                "Bambara": "bam_Latn",
                "Balinese": "ban_Latn",
                "Belarusian": "bel_Cyrl",
                "Bemba": "bem_Latn",
                "Bengali": "ben_Beng",
                "Bhojpuri": "bho_Deva",
                "Banjar (Arabic script)": "bjn_Arab",
                "Banjar (Latin script)": "bjn_Latn",
                "Tibetan": "bod_Tibt",
                "Bosnian": "bos_Latn",
                "Buginese": "bug_Latn",
                "Bulgarian": "bul_Cyrl",
                "Catalan": "cat_Latn",
                "Cebuano": "ceb_Latn",
                "Czech": "ces_Latn",
                "Chokwe": "cjk_Latn",
                "Central Kurdish": "ckb_Arab",
                "Crimean Tatar": "crh_Latn",
                "Welsh": "cym_Latn",
                "Danish": "dan_Latn",
                "German": "deu_Latn",
                "Dinka": "dik_Latn",
                "Jula": "dyu_Latn",
                "Dzongkha": "dzo_Tibt",
                "Greek": "ell_Grek",
                "English": "eng_Latn",
                "Esperanto": "epo_Latn",
                "Estonian": "est_Latn",
                "Basque": "eus_Latn",
                "Ewe": "ewe_Latn",
                "Faroese": "fao_Latn",
                "Persian": "pes_Arab",
                "Fijian": "fij_Latn",
                "Finnish": "fin_Latn",
                "Fon": "fon_Latn",
                "French": "fra_Latn",
                "Friulian": "fur_Latn",
                "Nigerian Fulfulde": "fuv_Latn",
                "Scottish Gaelic": "gla_Latn",
                "Irish": "gle_Latn",
                "Galician": "glg_Latn",
                "Guarani": "grn_Latn",
                "Gujarati": "guj_Gujr",
                "Haitian Creole": "hat_Latn",
                "Hausa": "hau_Latn",
                "Hebrew": "heb_Hebr",
                "Hindi": "hin_Deva",
                "Chhattisgarhi": "hne_Deva",
                "Croatian": "hrv_Latn",
                "Hungarian": "hun_Latn",
                "Armenian": "hye_Armn",
                "Igbo": "ibo_Latn",
                "Iloko": "ilo_Latn",
                "Indonesian": "ind_Latn",
                "Icelandic": "isl_Latn",
                "Italian": "ita_Latn",
                "Javanese": "jav_Latn",
                "Japanese": "jpn_Jpan",
                "Kabyle": "kab_Latn",
                "Kachin": "kac_Latn",
                "Arabic": "arb_Arab",  # was "ar_AR", an mBART code the NLLB tokenizer does not know
                "Chinese": "zho_Hans",
                "Spanish": "spa_Latn",
                "Dutch": "nld_Latn",
                "Kazakh": "kaz_Cyrl",
                "Korean": "kor_Hang",
                "Lithuanian": "lit_Latn",
                "Malayalam": "mal_Mlym",
                "Marathi": "mar_Deva",
                "Nepali": "npi_Deva",  # was "ne_NP", an mBART code the NLLB tokenizer does not know
                "Polish": "pol_Latn",
                "Portuguese": "por_Latn",
                "Russian": "rus_Cyrl",
                "Sinhala": "sin_Sinh",
                "Tamil": "tam_Taml",
                "Turkish": "tur_Latn",
                "Ukrainian": "ukr_Cyrl",
                "Urdu": "urd_Arab",
                "Vietnamese": "vie_Latn",
                "Thai": "tha_Thai",
            }
            return d[original_language]

        def process_gpu_translate_result(temp_outputs):
            # Flatten per-batch results into one row per input and checkpoint them to disk.
            outputs = []
            for temp_output in temp_outputs:
                length = len(temp_output[0]["generated_translation"])
                for i in range(length):
                    temp = []
                    for trans in temp_output:
                        temp.append({
                            "target_language": trans["target_language"],
                            "generated_translation": trans["generated_translation"][i],
                        })
                    outputs.append(temp)
            excel_writer = ExcelFileWriter()
            excel_writer.write_text(os.path.join(parent_dir, r"temp/empty.xlsx"), outputs, 'A', 1, len(outputs))

        self.tokenizer.src_lang = language_mapping(original_language)
        if self.device_name == "cpu":
            # Tokenize input (truncation added so max_length actually takes effect)
            input_ids = self.tokenizer(inputs, return_tensors="pt", padding=True, truncation=True, max_length=128).to(self.device_name)
            output = []
            for target_language in target_languages:
                # Token id of the target-language code (lang_code_to_id exists in transformers 4.35;
                # newer releases dropped it in favor of convert_tokens_to_ids)
                target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
                # Generate translation
                generated_tokens = self.model.generate(
                    **input_ids,
                    forced_bos_token_id=target_lang_code,
                    max_length=128,
                )
                generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
                # Append result to output
                output.append({
                    "target_language": target_language,
                    "generated_translation": generated_translation,
                })
            # Regroup: one list per input row, each holding all target-language translations
            outputs = []
            length = len(output[0]["generated_translation"])
            for i in range(length):
                temp = []
                for trans in output:
                    temp.append({
                        "target_language": trans["target_language"],
                        "generated_translation": trans["generated_translation"][i],
                    })
                outputs.append(temp)
            return outputs
        else:
            # max batch size = available GPU memory (bytes) / 4 / (tensor size + trainable parameters)
            # max_batch_size = 10
            # Ensure batch size is within model limits:
            print("length of inputs: ", len(inputs))
            batch_size = min(len(inputs), int(max_batch_size))
            batches = [inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)]
            print("number of batches: ", len(batches))
            temp_outputs = []
            processed_num = 0
            for index, batch in enumerate(batches):
                # Tokenize input
                input_ids = self.tokenizer(batch, return_tensors="pt", padding=True).to(self.device_name)
                temp = []
                for target_language in target_languages:
                    target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
                    generated_tokens = self.model.generate(
                        **input_ids,
                        forced_bos_token_id=target_lang_code,
                    )
                    generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
                    # Append result to output
                    temp.append({
                        "target_language": target_language,
                        "generated_translation": generated_translation,
                    })
                del input_ids  # release the tokenized batch; the earlier .to('cpu') returned a copy and was a no-op
                temp_outputs.append(temp)
                processed_num += len(batch)
                # Checkpoint partial results roughly every 1000 processed inputs
                # (uses batch_size rather than the raw max_batch_size argument, which may arrive as a string)
                if (index + 1) * batch_size // 1000 - index * batch_size // 1000 == 1:
                    print("Already processed number: ", processed_num)
                    process_gpu_translate_result(temp_outputs)
            outputs = []
            for temp_output in temp_outputs:
                length = len(temp_output[0]["generated_translation"])
                for i in range(length):
                    temp = []
                    for trans in temp_output:
                        temp.append({
                            "target_language": trans["target_language"],
                            "generated_translation": trans["generated_translation"][i],
                        })
                    outputs.append(temp)
            return outputs
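A brief usage sketch of the class above; the argument values are illustrative, `selected_lora_model` is accepted but unused in this version, and the CPU branch ignores `max_batch_size`:

model = Model("facebook/nllb-200-distilled-1.3B", selected_lora_model=None, selected_gpu="cpu")
results = model.generate(
    inputs=["Hello, world!", "How are you?"],
    original_language="English",
    target_languages=["French", "German"],
    max_batch_size=10,
)
# results holds one list per input row, each entry a
# {"target_language", "generated_translation"} dict per requested language.
print(results[0])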