import logging
import os

import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer


class TranslationModel:
    def __init__(self, cache_dir="models/"):
        self.device = torch.device("cpu")
        logging.info("Using CPU for translations")

        # Ensure cache directory exists
        os.makedirs(cache_dir, exist_ok=True)

        model_name = "facebook/m2m100_418M"  # Smaller model

        try:
            # Try to load from local cache first
            self.tokenizer = M2M100Tokenizer.from_pretrained(
                cache_dir, local_files_only=True
            )
            self.model = M2M100ForConditionalGeneration.from_pretrained(
                cache_dir,
                local_files_only=True,
                device_map="cpu",
                low_cpu_mem_usage=True,
            )
        except OSError:
            # If not in cache (missing or incomplete files), download and save
            self.tokenizer = M2M100Tokenizer.from_pretrained(model_name)
            self.model = M2M100ForConditionalGeneration.from_pretrained(
                model_name,
                device_map="cpu",
                low_cpu_mem_usage=True,
            )
            # Save for offline use
            self.tokenizer.save_pretrained(cache_dir)
            self.model.save_pretrained(cache_dir)

        self.model.eval()

    def translate(self, text: str, source_lang: str, target_lang: str) -> str:
        try:
            # M2M100 needs the source language set before tokenizing
            self.tokenizer.src_lang = source_lang
            encoded = self.tokenizer(
                text, return_tensors="pt", max_length=512, truncation=True
            )

            with torch.no_grad():
                generated = self.model.generate(
                    **encoded,
                    # Force the first generated token to the target language id
                    forced_bos_token_id=self.tokenizer.get_lang_id(target_lang),
                    max_length=128,
                    num_beams=2,
                    length_penalty=0.6,
                )

            return self.tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
        except Exception as e:
            logging.error(f"Translation error: {str(e)}")
            return f"Translation error: {str(e)}"
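

# Minimal usage sketch (illustrative only, not part of the class above):
# language codes follow the M2M100 convention, e.g. "en", "fr", "de".
# Assumes torch and transformers are installed and the model can be
# downloaded (or is already cached in models/).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    translator = TranslationModel(cache_dir="models/")
    print(translator.translate("Hello, how are you?", source_lang="en", target_lang="fr"))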