Spaces:
Runtime error
Runtime error
Update off_topic.py
Browse files- off_topic.py +18 -2
off_topic.py
CHANGED
@@ -21,7 +21,15 @@ class Translator:
|
|
21 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
22 |
model_id)
|
23 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(self.device)
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
@property
|
27 |
def _language_code_mapper(self):
|
@@ -33,12 +41,20 @@ class Translator:
|
|
33 |
return {"en": "en",
|
34 |
"es": "es",
|
35 |
"pt": "pt"}
|
|
|
|
|
|
|
|
|
36 |
|
37 |
def translate(self, texts: List[str], src_lang: str, dest_lang: str = "en", max_length: int = 100):
|
38 |
self.tokenizer.src_lang = self._language_code_mapper[src_lang]
|
39 |
inputs = self.tokenizer(texts, return_tensors="pt").to(self.device)
|
|
|
|
|
|
|
|
|
40 |
translated_tokens = self.model.generate(
|
41 |
-
**inputs, forced_bos_token_id=
|
42 |
)
|
43 |
return self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
44 |
|
|
|
21 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
22 |
model_id)
|
23 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(self.device)
|
24 |
+
|
25 |
+
@property
|
26 |
+
def _bos_token_attr(self):
|
27 |
+
if hasattr(self.tokenizer, "get_lang_id"):
|
28 |
+
return self.tokenizer.get_lang_id
|
29 |
+
elif hasattr(self.tokenizer, "lang_code_to_id"):
|
30 |
+
return self.tokenizer.lang_code_to_id
|
31 |
+
else:
|
32 |
+
return
|
33 |
|
34 |
@property
|
35 |
def _language_code_mapper(self):
|
|
|
41 |
return {"en": "en",
|
42 |
"es": "es",
|
43 |
"pt": "pt"}
|
44 |
+
else:
|
45 |
+
return {"en": "eng",
|
46 |
+
"es": "spa",
|
47 |
+
"pt": "por"}
|
48 |
|
49 |
def translate(self, texts: List[str], src_lang: str, dest_lang: str = "en", max_length: int = 100):
|
50 |
self.tokenizer.src_lang = self._language_code_mapper[src_lang]
|
51 |
inputs = self.tokenizer(texts, return_tensors="pt").to(self.device)
|
52 |
+
if "opus" in self.model_id.lower():
|
53 |
+
forced_bos_token_id = None
|
54 |
+
else:
|
55 |
+
forced_bos_token_id = self._bos_token_attr[self._language_code_mapper["en"]]
|
56 |
translated_tokens = self.model.generate(
|
57 |
+
**inputs, forced_bos_token_id=forced_bos_token_id, max_length=max_length
|
58 |
)
|
59 |
return self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
60 |
|