Spaces:
Runtime error
Runtime error
Update off_topic.py
Browse files- off_topic.py +18 -2
off_topic.py
CHANGED
|
@@ -21,7 +21,15 @@ class Translator:
|
|
| 21 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 22 |
model_id)
|
| 23 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(self.device)
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
@property
|
| 27 |
def _language_code_mapper(self):
|
|
@@ -33,12 +41,20 @@ class Translator:
|
|
| 33 |
return {"en": "en",
|
| 34 |
"es": "es",
|
| 35 |
"pt": "pt"}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
def translate(self, texts: List[str], src_lang: str, dest_lang: str = "en", max_length: int = 100):
|
| 38 |
self.tokenizer.src_lang = self._language_code_mapper[src_lang]
|
| 39 |
inputs = self.tokenizer(texts, return_tensors="pt").to(self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
translated_tokens = self.model.generate(
|
| 41 |
-
**inputs, forced_bos_token_id=
|
| 42 |
)
|
| 43 |
return self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
| 44 |
|
|
|
|
| 21 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 22 |
model_id)
|
| 23 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(self.device)
|
| 24 |
+
|
| 25 |
+
@property
|
| 26 |
+
def _bos_token_attr(self):
|
| 27 |
+
if hasattr(self.tokenizer, "get_lang_id"):
|
| 28 |
+
return self.tokenizer.get_lang_id
|
| 29 |
+
elif hasattr(self.tokenizer, "lang_code_to_id"):
|
| 30 |
+
return self.tokenizer.lang_code_to_id
|
| 31 |
+
else:
|
| 32 |
+
return
|
| 33 |
|
| 34 |
@property
|
| 35 |
def _language_code_mapper(self):
|
|
|
|
| 41 |
return {"en": "en",
|
| 42 |
"es": "es",
|
| 43 |
"pt": "pt"}
|
| 44 |
+
else:
|
| 45 |
+
return {"en": "eng",
|
| 46 |
+
"es": "spa",
|
| 47 |
+
"pt": "por"}
|
| 48 |
|
| 49 |
def translate(self, texts: List[str], src_lang: str, dest_lang: str = "en", max_length: int = 100):
    """Translate a batch of texts from ``src_lang`` into ``dest_lang``.

    Parameters
    ----------
    texts : List[str]
        Sentences to translate.
    src_lang : str
        Source-language key understood by ``_language_code_mapper``.
    dest_lang : str, default "en"
        Target-language key; mapped to the model-specific code.
    max_length : int, default 100
        Upper bound passed to ``model.generate``.

    Returns
    -------
    list[str]
        Decoded translations with special tokens stripped.
    """
    self.tokenizer.src_lang = self._language_code_mapper[src_lang]
    # NOTE(review): for batches of several sentences the tokenizer may
    # need padding=True to build a rectangular tensor — confirm against
    # real callers before enabling.
    inputs = self.tokenizer(texts, return_tensors="pt").to(self.device)
    if "opus" in self.model_id.lower():
        # OPUS/Marian checkpoints fix the target language per model,
        # so no forced BOS token is required.
        forced_bos_token_id = None
    else:
        # BUG FIX: the original hard-coded "en" here, silently ignoring
        # the dest_lang argument. Use the requested target language
        # (default is still "en", so existing callers are unaffected).
        target_code = self._language_code_mapper[dest_lang]
        bos = self._bos_token_attr
        # BUG FIX: _bos_token_attr may be a callable (get_lang_id) or a
        # mapping (lang_code_to_id); subscripting a bound method raised
        # TypeError. Support both access styles.
        forced_bos_token_id = bos(target_code) if callable(bos) else bos[target_code]
    translated_tokens = self.model.generate(
        **inputs, forced_bos_token_id=forced_bos_token_id, max_length=max_length
    )
    return self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
| 60 |
|