from transformers import NllbTokenizer, AutoModelForSeq2SeqLM


def create_tokenizer_with_new_lang(model_id, new_lang):
    """
    Add a new language token to the tokenizer vocabulary
    (this has to be re-done every time the tokenizer is initialized).
    Note: this patches NllbTokenizer internals (lang_code_to_id,
    fairseq_tokens_to_ids, ...) as they exist in older transformers releases.
    """
    tokenizer = NllbTokenizer.from_pretrained(model_id)
    old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
    # register the new language code at the position previously held by <mask>
    tokenizer.lang_code_to_id[new_lang] = old_len - 1
    tokenizer.id_to_lang_code[old_len - 1] = new_lang
    # always move "<mask>" to the last position
    tokenizer.fairseq_tokens_to_ids["<mask>"] = len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset

    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
    if new_lang not in tokenizer._additional_special_tokens:
        tokenizer._additional_special_tokens.append(new_lang)
    # clear the added token encoder; otherwise a new token may end up there by mistake
    tokenizer.added_tokens_encoder = {}
    tokenizer.added_tokens_decoder = {}
    return tokenizer
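

# `Translator.from_pretrained` below optionally calls a
# `create_model_with_new_lang` helper that is commented out there. The
# definition here is only a sketch of what such a helper might do, not the
# original implementation: it resizes the embedding matrix for the extra
# token and seeds the new language embedding from a related language, which
# is a common way to initialize a new NLLB language code.
def create_model_with_new_lang(model_id, new_lang, similar_lang):
    tokenizer = create_tokenizer_with_new_lang(model_id, new_lang)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    # one extra row for the new language token
    model.resize_token_embeddings(len(tokenizer))
    embeds = model.model.shared.weight.data
    new_id = tokenizer.convert_tokens_to_ids(new_lang)
    mask_id = tokenizer.convert_tokens_to_ids('<mask>')
    # the new language token takes over the old <mask> slot, so first move
    # the <mask> embedding to its new (last) position ...
    embeds[mask_id] = embeds[new_id]
    # ... then seed the new language token from a similar language
    embeds[new_id] = embeds[tokenizer.convert_tokens_to_ids(similar_lang)]
    return model, tokenizer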
class Translator:
    @classmethod
    def from_pretrained(cls, path, new_lang='moo_Latn'):
        # Does the model need adaptation or not?
        # model, tokenizer = create_model_with_new_lang(
        #     model_id=path,
        #     new_lang=new_lang,
        #     similar_lang='deu_Latn',
        # )
        tokenizer = create_tokenizer_with_new_lang(path, new_lang)
        model = AutoModelForSeq2SeqLM.from_pretrained(path)
        return cls(model, tokenizer)

    def __init__(self, model, tokenizer) -> None:
        self.model = model
        self.tokenizer = tokenizer
        # self.model.cuda()

    def translate(self, text, src_lang='moo_Latn', tgt_lang='deu_Latn', a=32, b=3, max_input_length=1024, num_beams=4, **kwargs):
        self.tokenizer.src_lang = src_lang
        self.tokenizer.tgt_lang = tgt_lang
        inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=max_input_length)
        result = self.model.generate(
            **inputs.to(self.model.device),
            # force the decoder to start with the target language token
            forced_bos_token_id=self.tokenizer.convert_tokens_to_ids(tgt_lang),
            # length budget: `a` tokens plus `b` tokens per input token
            max_new_tokens=int(a + b * inputs.input_ids.shape[1]),
            num_beams=num_beams,
            **kwargs,
        )
        return self.tokenizer.batch_decode(result, skip_special_tokens=True)
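

# Example usage: a minimal sketch. "facebook/nllb-200-distilled-600M" stands in
# for whatever checkpoint `path` points to in practice (ideally one already
# fine-tuned on the new language); eng_Latn -> deu_Latn just exercises the
# pipeline with two stock NLLB codes.
if __name__ == '__main__':
    translator = Translator.from_pretrained('facebook/nllb-200-distilled-600M')
    print(translator.translate(['How are you?'], src_lang='eng_Latn', tgt_lang='deu_Latn'))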