afri_app/modules/translation_model.py
import logging
import os

import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer


class TranslationModel:
    def __init__(self, cache_dir="models/"):
        self.device = torch.device("cpu")
        logging.info("Using CPU for translations")

        # Ensure the cache directory exists
        os.makedirs(cache_dir, exist_ok=True)

        model_name = "facebook/m2m100_418M"  # 418M-parameter variant; smaller than m2m100_1.2B

        try:
            # Try to load from the local cache first (offline-friendly)
            self.tokenizer = M2M100Tokenizer.from_pretrained(
                cache_dir,
                local_files_only=True,
            )
            self.model = M2M100ForConditionalGeneration.from_pretrained(
                cache_dir,
                local_files_only=True,
                device_map="cpu",
                low_cpu_mem_usage=True,
            )
        except OSError:
            # Not in the cache yet: download the weights, then save them for offline use
            self.tokenizer = M2M100Tokenizer.from_pretrained(model_name)
            self.model = M2M100ForConditionalGeneration.from_pretrained(
                model_name,
                device_map="cpu",
                low_cpu_mem_usage=True,
            )
            self.tokenizer.save_pretrained(cache_dir)
            self.model.save_pretrained(cache_dir)

        self.model.eval()  # inference mode: disables dropout

    def translate(self, text: str, source_lang: str, target_lang: str) -> str:
        try:
            # M2M100 reads the source language from the tokenizer's state
            self.tokenizer.src_lang = source_lang
            encoded = self.tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
            with torch.no_grad():
                generated = self.model.generate(
                    **encoded,
                    # Force the first generated token to the target-language code
                    forced_bos_token_id=self.tokenizer.get_lang_id(target_lang),
                    max_length=128,  # note: caps output well below the 512-token input limit
                    num_beams=2,
                    length_penalty=0.6,
                )
            return self.tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
        except Exception as e:
            logging.error(f"Translation error: {e}")
            return f"Translation error: {e}"