from xtransformer import Transformer
import torch
import torch.nn as nn
from nepalitokenizers import SentencePiece
from huggingface_hub import hf_hub_download
import re

# Initialize tokenizers
tokenizer_en = SentencePiece()  # English tokenizer
tokenizer_ne = SentencePiece()  # Nepali tokenizer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define special tokens and their IDs. The token strings must be distinct,
# otherwise the SPECIAL_TOKENS dict below collapses to a single key.
START_TOKEN = '<START>'
PADDING_TOKEN = '<PADDING>'
END_TOKEN = '<END>'
SPECIAL_TOKENS = {
    START_TOKEN: max(tokenizer_en.get_vocab_size(), tokenizer_ne.get_vocab_size()),
    PADDING_TOKEN: max(tokenizer_en.get_vocab_size(), tokenizer_ne.get_vocab_size()) + 1,
    END_TOKEN: max(tokenizer_en.get_vocab_size(), tokenizer_ne.get_vocab_size()) + 2,
}

# Update vocabulary sizes to account for the three special tokens
en_vocab_size = tokenizer_en.get_vocab_size() + len(SPECIAL_TOKENS)
ne_vocab_size = tokenizer_ne.get_vocab_size() + len(SPECIAL_TOKENS)

# Create token-to-index mappings. get_vocab() returns a token -> id dict, so
# copy it directly; enumerating it would assign arbitrary sequential indices
# that need not match the tokenizer's own IDs.
english_to_index = dict(tokenizer_en.get_vocab())
nepali_to_index = dict(tokenizer_ne.get_vocab())
english_to_index.update(SPECIAL_TOKENS)
nepali_to_index.update(SPECIAL_TOKENS)

# Hyperparameters
max_sequence_length = 100
d_model = 512
batch_size = 32
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
encoder_layers = 6
decoder_layers = 4

# Initialize the Transformer model
transformer = Transformer(
    d_model, ffn_hidden, num_heads, drop_prob,
    encoder_layers, decoder_layers, max_sequence_length,
    ne_vocab_size, english_to_index, nepali_to_index,
    START_TOKEN, END_TOKEN, PADDING_TOKEN
).to(device)

# Function to encode text with special tokens
def encode_with_special_tokens(text, tokenizer, max_sequence_length, add_start_end=True):
    tokens = tokenizer.encode(text).ids
    if add_start_end:
        # Reserve two slots so START and END survive truncation
        tokens = tokens[:max_sequence_length - 2]
        tokens = [SPECIAL_TOKENS[START_TOKEN]] + tokens + [SPECIAL_TOKENS[END_TOKEN]]
    else:
        tokens = tokens[:max_sequence_length]
    padding = [SPECIAL_TOKENS[PADDING_TOKEN]] * (max_sequence_length - len(tokens))
    return tokens + padding

# Function to decode token IDs, filtering out special tokens
def decode_with_special_tokens(token_ids, tokenizer):
    token_ids = [token_id for token_id in token_ids if token_id not in SPECIAL_TOKENS.values()]
    return tokenizer.decode(token_ids)

# Mask creation
NEG_INFTY = -1e9

def create_masks(eng_batch, decoder_input):
    batch_size, enc_seq_length = eng_batch.size(0), eng_batch.size(1)
    dec_seq_length = decoder_input.size(1)
    device = eng_batch.device
    # Boolean masks: True marks positions that must be blocked; shapes
    # broadcast over (batch, heads, query, key)
    encoder_padding_mask = (eng_batch == SPECIAL_TOKENS[PADDING_TOKEN]).unsqueeze(1).unsqueeze(2)
    decoder_padding_mask_self = (decoder_input == SPECIAL_TOKENS[PADDING_TOKEN]).unsqueeze(1).unsqueeze(2)
    # Upper-triangular look-ahead mask prevents attending to future positions
    look_ahead_mask = torch.triu(
        torch.ones(dec_seq_length, dec_seq_length, device=device), diagonal=1
    ).bool().unsqueeze(0).unsqueeze(0)
    decoder_padding_mask_cross = (eng_batch == SPECIAL_TOKENS[PADDING_TOKEN]).unsqueeze(1).unsqueeze(2)
    # Convert to additive masks: 0 where attention is allowed, -1e9 where blocked
    encoder_mask = encoder_padding_mask * NEG_INFTY
    decoder_self_mask = (look_ahead_mask | decoder_padding_mask_self) * NEG_INFTY
    decoder_cross_mask = decoder_padding_mask_cross * NEG_INFTY
    return encoder_mask, decoder_self_mask, decoder_cross_mask

# Translation function
def translate(sentence):
    def is_english(text):
        # Check if the text contains only English letters and spaces
        return re.match(r'^[a-zA-Z\s]+$', text) is not None

    # Determine which checkpoint to use based on the input language
    if is_english(sentence):
        clean_sentence = re.sub(r'[^a-zA-Z0-9\s]', '', sentence.strip()).lower()
        checkpoint_file = "checkpoint_en_ne.pth"
        print('using english to nepali transformer')
    else:
        clean_sentence = re.sub(r'[^ऀ-ॿ\s]', '', sentence.strip())
        checkpoint_file = "checkpoint_ne_en.pth"
        print('using nepali to english transformer')

    # Download the checkpoint from Hugging Face Hub
    try:
        checkpoint_path = hf_hub_download(
            repo_id="bishaltwr/xtransformer",
            filename=checkpoint_file,
            repo_type="model"
        )
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        transformer.load_state_dict(checkpoint['model_state'])
        transformer.eval()
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return f"Translation failed: Could not load model checkpoint ({str(e)})"

    with torch.no_grad():
        # Both tokenizers are identical SentencePiece instances, so the same
        # encoding path serves either translation direction
        eng_tokens = encode_with_special_tokens(clean_sentence, tokenizer_en, max_sequence_length)
        eng_batch = torch.tensor([eng_tokens]).to(device)
        # Decoder input: START followed by padding; positions are filled in
        # one at a time (greedy decoding)
        ne_batch = torch.tensor(
            [[SPECIAL_TOKENS[START_TOKEN]] + [SPECIAL_TOKENS[PADDING_TOKEN]] * (max_sequence_length - 1)]
        ).to(device)
        for i in range(1, max_sequence_length):
            encoder_mask, decoder_mask, cross_mask = create_masks(eng_batch, ne_batch)
            predictions = transformer(eng_batch, ne_batch, encoder_mask, decoder_mask, cross_mask)
            # Greedily pick the most likely next token from position i-1's logits
            next_token = torch.argmax(predictions[:, i - 1, :], dim=-1)
            if next_token.item() == SPECIAL_TOKENS[END_TOKEN]:
                break
            ne_batch[0, i] = next_token
        return decode_with_special_tokens(ne_batch[0].tolist(), tokenizer_ne)
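
# Example usage: a minimal sketch of calling translate() in both directions.
# The input sentences below are placeholders, and each call downloads the
# matching checkpoint from the Hugging Face Hub, so this assumes network
# access and valid checkpoint files in the bishaltwr/xtransformer repo.
if __name__ == '__main__':
    print(translate("how are you"))       # English -> Nepali (checkpoint_en_ne.pth)
    print(translate("तपाईंलाई कस्तो छ"))     # Nepali -> English (checkpoint_ne_en.pth)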