from transformers import NllbTokenizer, AutoModelForSeq2SeqLM


def create_tokenizer_with_new_lang(model_id: str, new_lang: str) -> "NllbTokenizer":
    """Load an NLLB tokenizer and register an extra language code in it.

    The new code is patched directly into the tokenizer's internal lookup
    tables, so this has to be redone every time the tokenizer is initialised.

    NOTE(review): this relies on non-public NllbTokenizer internals
    (``lang_code_to_id``, ``id_to_lang_code``, ``fairseq_tokens_to_ids``,
    ``fairseq_ids_to_tokens``, ``sp_model``, ``fairseq_offset``,
    ``_additional_special_tokens``) that may be missing in newer
    ``transformers`` releases -- confirm against the pinned version.

    Args:
        model_id: Model name or local path passed to
            ``NllbTokenizer.from_pretrained``.
        new_lang: Language code to register, e.g. ``"moo_Latn"``.

    Returns:
        The patched tokenizer.
    """
    tokenizer = NllbTokenizer.from_pretrained(model_id)
    # Vocabulary size not counting new_lang, in case the saved tokenizer
    # already carries it as an added token.
    old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
    # The new language code takes the last id of the original vocabulary...
    tokenizer.lang_code_to_id[new_lang] = old_len - 1
    tokenizer.id_to_lang_code[old_len - 1] = new_lang
    # ...and "<mask>" is always moved to the last position. The key must be the
    # literal mask token: with any other key "<mask>" would keep id old_len - 1
    # and collide with new_lang.
    tokenizer.fairseq_tokens_to_ids["<mask>"] = (
        len(tokenizer.sp_model)
        + len(tokenizer.lang_code_to_id)
        + tokenizer.fairseq_offset
    )
    tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
    # Rebuild the reverse mapping so id -> token lookups see the new layout.
    tokenizer.fairseq_ids_to_tokens = {
        v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()
    }
    if new_lang not in tokenizer._additional_special_tokens:
        tokenizer._additional_special_tokens.append(new_lang)
    # Clear the added-token tables; otherwise a new token may end up there by mistake.
    tokenizer.added_tokens_encoder = {}
    tokenizer.added_tokens_decoder = {}
    return tokenizer


class Translator:
    """Pair a seq2seq model with its tokenizer and translate text with it."""

    @classmethod
    def from_pretrained(cls, path, new_lang='moo_Latn'):
        """Load model and tokenizer from ``path``, registering ``new_lang``.

        Args:
            path: Model name or local path used for both the tokenizer and the
                ``AutoModelForSeq2SeqLM`` checkpoint.
            new_lang: Extra language code to register in the tokenizer.

        Returns:
            An instance of ``cls`` wrapping the loaded model and tokenizer.
        """
        # NOTE(review): the checkpoint is loaded as-is, so this assumes it was
        # already trained with an embedding slot for new_lang. An earlier draft
        # adapted the model here via a create_model_with_new_lang helper
        # (similar_lang='deu_Latn') that is not defined in this file -- confirm
        # whether such adaptation is still needed.
        tokenizer = create_tokenizer_with_new_lang(path, new_lang)
        model = AutoModelForSeq2SeqLM.from_pretrained(path)
        return cls(model, tokenizer)

    def __init__(self, model, tokenizer) -> None:
        """Store the model and tokenizer; device placement is left to the caller."""
        self.model = model
        self.tokenizer = tokenizer

    def translate(self, text: "str | list[str]", src_lang='moo_Latn',
                  tgt_lang='deu_Latn', a=32, b=3, max_input_length=1024,
                  num_beams=4, **kwargs) -> "list[str]":
        """Translate ``text`` from ``src_lang`` to ``tgt_lang``.

        Args:
            text: A string or a batch of strings to translate.
            src_lang: Source language code set on the tokenizer.
            tgt_lang: Target language code; its token id is forced as the
                first generated token.
            a: Constant term of the generation-length budget.
            b: Per-input-token factor of the budget; generation is capped at
                ``int(a + b * input_length)`` new tokens, where input_length is
                the padded input length.
            max_input_length: Inputs are truncated to this many tokens.
            num_beams: Beam width passed to ``generate``.
            **kwargs: Forwarded unchanged to ``self.model.generate``.

        Returns:
            The decoded translations with special tokens removed.
        """
        self.tokenizer.src_lang = src_lang
        self.tokenizer.tgt_lang = tgt_lang
        inputs = self.tokenizer(
            text,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=max_input_length,
        )
        result = self.model.generate(
            # Move the encoded batch to wherever the model currently lives.
            **inputs.to(self.model.device),
            forced_bos_token_id=self.tokenizer.convert_tokens_to_ids(tgt_lang),
            max_new_tokens=int(a + b * inputs.input_ids.shape[1]),
            num_beams=num_beams,
            **kwargs,
        )
        return self.tokenizer.batch_decode(result, skip_special_tokens=True)