SmokeyBandit committed
Commit 8fe2cdc · verified · 1 Parent(s): 30c9349

Update modules/translation_model.py

Files changed (1)
  1. modules/translation_model.py +35 -16
modules/translation_model.py CHANGED
@@ -1,39 +1,58 @@
 import torch
-from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
+from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
 import logging
+import os

 class TranslationModel:
     def __init__(self, cache_dir="models/"):
         self.device = torch.device("cpu")
         logging.info("Using CPU for translations")

-        self.model_name = "facebook/m2m100_1.2B"
-        self.tokenizer = M2M100Tokenizer.from_pretrained(
-            self.model_name,
-            cache_dir=cache_dir,
-            local_files_only=True  # Only use cached files
-        )
-        self.model = M2M100ForConditionalGeneration.from_pretrained(
-            self.model_name,
-            cache_dir=cache_dir,
-            local_files_only=True,
-            device_map="cpu",
-            low_cpu_mem_usage=True
-        )
+        # Ensure cache directory exists
+        os.makedirs(cache_dir, exist_ok=True)
+
+        model_name = "facebook/m2m100_418M"  # Smaller model
+        try:
+            # Try to load from local cache first
+            self.tokenizer = M2M100Tokenizer.from_pretrained(
+                cache_dir,
+                local_files_only=True
+            )
+            self.model = M2M100ForConditionalGeneration.from_pretrained(
+                cache_dir,
+                local_files_only=True,
+                device_map="cpu",
+                low_cpu_mem_usage=True
+            )
+        except:
+            # If not in cache, download and save
+            self.tokenizer = M2M100Tokenizer.from_pretrained(model_name)
+            self.model = M2M100ForConditionalGeneration.from_pretrained(
+                model_name,
+                device_map="cpu",
+                low_cpu_mem_usage=True
+            )
+            # Save for offline use
+            self.tokenizer.save_pretrained(cache_dir)
+            self.model.save_pretrained(cache_dir)
+
         self.model.eval()

     def translate(self, text: str, source_lang: str, target_lang: str) -> str:
         try:
             self.tokenizer.src_lang = source_lang
-            encoded = self.tokenizer(text, return_tensors="pt")
+            encoded = self.tokenizer(text, return_tensors="pt", max_length=512, truncation=True)

             with torch.no_grad():
                 generated = self.model.generate(
                     **encoded,
                     forced_bos_token_id=self.tokenizer.get_lang_id(target_lang),
-                    max_length=128
+                    max_length=128,
+                    num_beams=2,
+                    length_penalty=0.6
                 )

             return self.tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
         except Exception as e:
+            logging.error(f"Translation error: {str(e)}")
             return f"Translation error: {str(e)}"
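
For context, a minimal usage sketch of the class this commit updates, assuming the module path shown above and M2M100's two-letter language codes (e.g. "en", "fr"); the cache_dir argument mirrors the constructor default in the diff:

    from modules.translation_model import TranslationModel

    # First run downloads facebook/m2m100_418M and saves it under models/;
    # later runs hit the local_files_only=True branch and load offline.
    translator = TranslationModel(cache_dir="models/")
    print(translator.translate("Hello, world!", source_lang="en", target_lang="fr"))

Note that the fallback branch requires network access once to populate the cache, and translate() reports failures by returning a "Translation error: ..." string rather than raising.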