rxavier committed on
Commit 93c8d7a · 1 Parent(s): ab7d145

Update off_topic.py

Files changed (1): off_topic.py +18 -2
off_topic.py CHANGED
@@ -21,7 +21,15 @@ class Translator:
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_id)
         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(self.device)
-        self.bos_token_map = self.tokenizer.get_lang_id if hasattr(self.tokenizer, "get_lang_id") else self.tokenizer.lang_code_to_id
+
+    @property
+    def _bos_token_attr(self):
+        if hasattr(self.tokenizer, "get_lang_id"):
+            return self.tokenizer.get_lang_id
+        elif hasattr(self.tokenizer, "lang_code_to_id"):
+            return self.tokenizer.lang_code_to_id
+        else:
+            return
 
     @property
     def _language_code_mapper(self):
@@ -33,12 +41,20 @@ class Translator:
             return {"en": "en",
                     "es": "es",
                     "pt": "pt"}
+        else:
+            return {"en": "eng",
+                    "es": "spa",
+                    "pt": "por"}
 
     def translate(self, texts: List[str], src_lang: str, dest_lang: str = "en", max_length: int = 100):
         self.tokenizer.src_lang = self._language_code_mapper[src_lang]
         inputs = self.tokenizer(texts, return_tensors="pt").to(self.device)
+        if "opus" in self.model_id.lower():
+            forced_bos_token_id = None
+        else:
+            forced_bos_token_id = self._bos_token_attr[self._language_code_mapper["en"]]
         translated_tokens = self.model.generate(
-            **inputs, forced_bos_token_id=self.bos_token_map["eng_Latn"], max_length=max_length
+            **inputs, forced_bos_token_id=forced_bos_token_id, max_length=max_length
         )
         return self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
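
For context, the effect of this commit is to choose forced_bos_token_id per tokenizer family (and to skip it for OPUS-MT checkpoints) instead of always indexing a single bos_token_map. Below is a minimal standalone sketch of that selection logic; the helper name pick_forced_bos_token_id, the example model ids, and the language codes are illustrative assumptions, not code from this repository.

# Standalone sketch (not from this repo) of the BOS-token selection this commit introduces.
# Assumptions: M2M100-style tokenizers expose get_lang_id(), mBART-style tokenizers expose
# lang_code_to_id, and Marian/OPUS-MT checkpoints need no forced BOS token.
from typing import Optional

from transformers import AutoTokenizer


def pick_forced_bos_token_id(model_id: str, dest_code: str) -> Optional[int]:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if "opus" in model_id.lower():
        # OPUS-MT (Marian) checkpoints are single-direction; no forced BOS token is needed.
        return None
    if hasattr(tokenizer, "get_lang_id"):
        # e.g. facebook/m2m100_418M: get_lang_id is a method that takes a language code.
        return tokenizer.get_lang_id(dest_code)
    if hasattr(tokenizer, "lang_code_to_id"):
        # e.g. mBART-style tokenizers: lang_code_to_id is a dict keyed by language code.
        return tokenizer.lang_code_to_id[dest_code]
    return None


# Illustrative calls (model ids assumed, not necessarily the ones this project uses):
# pick_forced_bos_token_id("facebook/m2m100_418M", "en")        -> integer token id
# pick_forced_bos_token_id("Helsinki-NLP/opus-mt-es-en", "en")  -> None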