"""
BackTranslation class
-----------------------------------
"""

import random

from transformers import MarianMTModel, MarianTokenizer

from textattack.shared import AttackedText

from .sentence_transformation import SentenceTransformation

class BackTranslation(SentenceTransformation):
    """A sentence-level transformation that translates a text input into a
    target language and then translates it back into the source language.

    src_lang (string): source language
    target_lang (string): target language; for the list of supported languages, see the bottom of this file
    src_model: Hugging Face translation model that translates from the target language back to the source language
    target_model: Hugging Face translation model that translates from the source language to the target language
    chained_back_translation: run back translation in a chain for more perturbation (for example, en-es-en-fr-en)

    Example::

        >>> from textattack.transformations.sentence_transformations import BackTranslation
        >>> from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
        >>> from textattack.augmentation import Augmenter
        >>> transformation = BackTranslation()
        >>> constraints = [RepeatModification(), StopwordModification()]
        >>> augmenter = Augmenter(transformation=transformation, constraints=constraints)
        >>> s = 'What on earth are you doing here.'
        >>> augmenter.augment(s)
    """

    def __init__(
        self,
        src_lang="en",
        target_lang="es",
        src_model="Helsinki-NLP/opus-mt-ROMANCE-en",
        target_model="Helsinki-NLP/opus-mt-en-ROMANCE",
        chained_back_translation=0,
    ):
        self.src_lang = src_lang
        self.target_lang = target_lang
        self.target_model = MarianMTModel.from_pretrained(target_model)
        self.target_tokenizer = MarianTokenizer.from_pretrained(target_model)
        self.src_model = MarianMTModel.from_pretrained(src_model)
        self.src_tokenizer = MarianTokenizer.from_pretrained(src_model)
        self.chained_back_translation = chained_back_translation
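
    # Note: with the defaults above, target_model / target_tokenizer
    # ("opus-mt-en-ROMANCE") handle the en -> Romance direction, while
    # src_model / src_tokenizer ("opus-mt-ROMANCE-en") handle the
    # Romance -> en direction used to translate back.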

    def translate(self, input, model, tokenizer, lang="es"):
        # Change the text to the model's format: multilingual Marian models
        # expect a ">>lang<<" target-language token in front of the text.
        src_texts = []
        if lang == "en":
            src_texts.append(input[0])
        else:
            if ">>" not in lang and "<<" not in lang:
                lang = ">>" + lang + "<< "
            src_texts.append(lang + input[0])

        # Tokenize the input (tokenizer.__call__ replaces the deprecated
        # prepare_seq2seq_batch used by older transformers versions).
        encoded_input = tokenizer(src_texts, padding=True, return_tensors="pt")

        # Translate the input and decode it back to text.
        translated = model.generate(**encoded_input)
        translated_input = tokenizer.batch_decode(translated, skip_special_tokens=True)
        return translated_input
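
    # Example: with the default "opus-mt-en-ROMANCE" tokenizer, calling
    # translate(["What on earth are you doing here."], model, tokenizer, "es")
    # tokenizes the batch [">>es<< What on earth are you doing here."] and feeds
    # it to model.generate(), while lang="en" (the ROMANCE -> en direction)
    # passes the text through unchanged.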

    def _get_transformations(self, current_text, indices_to_modify):
        # `indices_to_modify` is unused: sentence-level transformations rewrite
        # the whole text rather than individual words.
        transformed_texts = []
        current_text = current_text.text

        # To perform chained back translation, a random list of target languages
        # is sampled from the languages supported by the provided model.
        if self.chained_back_translation:
            list_of_target_lang = random.sample(
                self.target_tokenizer.supported_language_codes,
                self.chained_back_translation,
            )
            for target_lang in list_of_target_lang:
                target_language_text = self.translate(
                    [current_text],
                    self.target_model,
                    self.target_tokenizer,
                    target_lang,
                )
                src_language_text = self.translate(
                    target_language_text,
                    self.src_model,
                    self.src_tokenizer,
                    self.src_lang,
                )
                current_text = src_language_text[0]
            return [AttackedText(current_text)]

        # Single back translation: source -> target language -> back to source.
        target_language_text = self.translate(
            [current_text], self.target_model, self.target_tokenizer, self.target_lang
        )
        src_language_text = self.translate(
            target_language_text, self.src_model, self.src_tokenizer, self.src_lang
        )
        transformed_texts.append(AttackedText(src_language_text[0]))
        return transformed_texts


"""
List of supported languages:
['fr', 'es', 'it', 'pt', 'pt_br', 'ro', 'ca', 'gl', 'pt_BR<<', 'la<<', 'wa<<',
 'fur<<', 'oc<<', 'fr_CA<<', 'sc<<', 'es_ES', 'es_MX', 'es_AR', 'es_PR',
 'es_UY', 'es_CL', 'es_CO', 'es_CR', 'es_GT', 'es_HN', 'es_NI', 'es_PA',
 'es_PE', 'es_VE', 'es_DO', 'es_EC', 'es_SV', 'an', 'pt_PT', 'frp', 'lad',
 'vec', 'fr_FR', 'co', 'it_IT', 'lld', 'lij', 'lmo', 'nap', 'rm', 'scn', 'mwl']
"""
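
# Usage sketch (mirrors the Example in the class docstring; kept as a comment
# because this module is normally imported as part of the textattack package):
#
#   from textattack.augmentation import Augmenter
#   from textattack.transformations.sentence_transformations import BackTranslation
#
#   # chained_back_translation=2 back-translates through two randomly sampled
#   # target languages (for example en -> es -> en -> fr -> en) for a stronger
#   # perturbation than the default single en -> es -> en round trip.
#   transformation = BackTranslation(chained_back_translation=2)
#   augmenter = Augmenter(transformation=transformation)
#   augmenter.augment("What on earth are you doing here.")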