Spaces:
Runtime error
Runtime error
import torch | |
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer | |
import logging | |
import os | |
class TranslationModel:
    """CPU wrapper around a Hugging Face M2M100 many-to-many translation model.

    The tokenizer and model are loaded once, at construction time. A copy
    previously saved under ``cache_dir`` is preferred; otherwise the checkpoint
    is downloaded and then saved to ``cache_dir`` so later runs can load it
    offline.
    """

    def __init__(self, cache_dir="models/", model_name="facebook/m2m100_418M"):
        """Load the tokenizer and model, from ``cache_dir`` when possible.

        Args:
            cache_dir: Directory holding a local copy of the model. It is
                created if missing. NOTE: it is not keyed by model name, so
                whatever checkpoint is already saved there is used
                regardless of ``model_name``.
            model_name: Checkpoint identifier passed to ``from_pretrained``
                when ``cache_dir`` holds no loadable copy.

        Raises:
            Whatever ``from_pretrained`` raises if the download fallback also
            fails (presumably ``OSError`` when offline with no local copy).
        """
        self.device = torch.device("cpu")
        logging.info("Using CPU for translations")
        # Ensure cache directory exists
        os.makedirs(cache_dir, exist_ok=True)
        try:
            # Try the local copy first, without touching the network.
            self._load(cache_dir, local_files_only=True)
        except Exception as exc:
            # Any failure here means "no usable local copy". Catch Exception
            # rather than using a bare except, so Ctrl-C / SystemExit still
            # propagate.
            logging.info(
                "No usable model in %s (%s); loading %s", cache_dir, exc, model_name
            )
            self._load(model_name)
            self._save(cache_dir)
        # Inference only: disable dropout and similar training-time behavior.
        self.model.eval()

    def _load(self, source, **kwargs):
        """Set ``self.tokenizer`` and ``self.model`` from ``source`` (a path or checkpoint id)."""
        self.tokenizer = M2M100Tokenizer.from_pretrained(source, **kwargs)
        self.model = M2M100ForConditionalGeneration.from_pretrained(
            source,
            device_map="cpu",
            low_cpu_mem_usage=True,
            **kwargs,
        )

    def _save(self, cache_dir):
        """Best-effort: write the tokenizer and model to ``cache_dir`` for offline reuse."""
        try:
            self.tokenizer.save_pretrained(cache_dir)
            self.model.save_pretrained(cache_dir)
        except OSError as exc:
            # Caching is an optimisation. A read-only or full disk must not
            # prevent using the model that is already loaded in memory.
            logging.warning("Could not cache model in %s: %s", cache_dir, exc)

    def translate(self, text: str, source_lang: str, target_lang: str) -> str:
        """Translate ``text`` from ``source_lang`` into ``target_lang``.

        Args:
            text: Input text. It is truncated to 512 tokens before generation.
            source_lang: Language code assigned to the tokenizer's
                ``src_lang``.
            target_lang: Language code passed to the tokenizer's
                ``get_lang_id``.

        Returns:
            The translated text; ``""`` for empty or whitespace-only input.
            Generation is capped at ``max_length=128`` tokens, so translations
            of long inputs may be cut short. On any failure, the error is
            logged and a string starting with ``"Translation error: "`` is
            returned instead of raising.
        """
        try:
            # Nothing to translate: skip tokenization and generation entirely.
            if not text.strip():
                return ""
            # Set the source language on the tokenizer before encoding.
            self.tokenizer.src_lang = source_lang
            encoded = self.tokenizer(
                text, return_tensors="pt", max_length=512, truncation=True
            )
            with torch.no_grad():
                generated = self.model.generate(
                    **encoded,
                    # Force the first decoded token to the target-language id
                    # so the output is produced in target_lang.
                    forced_bos_token_id=self.tokenizer.get_lang_id(target_lang),
                    max_length=128,
                    num_beams=2,
                    length_penalty=0.6,
                )
            return self.tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
        except Exception as e:
            # Callers receive an error string rather than an exception; keep
            # the traceback in the log.
            logging.exception("Translation error: %s", e)
            return f"Translation error: {str(e)}"