SmokeyBandit committed
Commit 550acbc · verified · 1 Parent(s): 89da072

Update modules/translation_model.py

Files changed (1)
  1. modules/translation_model.py +15 -8
modules/translation_model.py CHANGED
@@ -4,27 +4,34 @@ import logging
 
 class TranslationModel:
     def __init__(self, cache_dir="models/"):
-        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-        if self.device.type == "cpu":
-            logging.warning("GPU not found, using CPU. Translation will be slower.")
+        self.device = torch.device("cpu")
+        logging.info("Using CPU for translations")
 
         self.model_name = "facebook/m2m100_1.2B"
-        self.tokenizer = M2M100Tokenizer.from_pretrained(self.model_name, cache_dir=cache_dir)
+        self.tokenizer = M2M100Tokenizer.from_pretrained(
+            self.model_name,
+            cache_dir=cache_dir,
+            local_files_only=True  # Only use cached files
+        )
         self.model = M2M100ForConditionalGeneration.from_pretrained(
             self.model_name,
-            cache_dir=cache_dir
-        ).to(self.device)
+            cache_dir=cache_dir,
+            local_files_only=True,
+            device_map="cpu",
+            low_cpu_mem_usage=True
+        )
         self.model.eval()
 
     def translate(self, text: str, source_lang: str, target_lang: str) -> str:
         try:
             self.tokenizer.src_lang = source_lang
-            encoded = self.tokenizer(text, return_tensors="pt").to(self.device)
+            encoded = self.tokenizer(text, return_tensors="pt")
 
             with torch.no_grad():
                 generated = self.model.generate(
                     **encoded,
-                    forced_bos_token_id=self.tokenizer.get_lang_id(target_lang)
+                    forced_bos_token_id=self.tokenizer.get_lang_id(target_lang),
+                    max_length=128
                 )
 
             return self.tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
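
Because the updated loader passes local_files_only=True, both the tokenizer and the model weights must already be present in cache_dir before TranslationModel is constructed. A minimal sketch of the one-time cache population and a subsequent offline call follows; the download step, the import path, and the "en"/"fr" language codes are illustrative assumptions, not part of this commit:

# One-time, with network access: populate cache_dir so that later
# loads with local_files_only=True can succeed offline.
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B", cache_dir="models/")
M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_1.2B", cache_dir="models/")

# Afterwards, fully offline (hypothetical import path based on the file's location):
from modules.translation_model import TranslationModel

translator = TranslationModel(cache_dir="models/")
# M2M100 uses short language codes such as "en" and "fr"
print(translator.translate("Hello, world!", source_lang="en", target_lang="fr"))

Note that device_map="cpu" and low_cpu_mem_usage=True depend on the accelerate package; in recent transformers releases, from_pretrained raises an ImportError when either option is used without it installed.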