from xtransformer import Transformer
import torch
import torch.nn as nn
from nepalitokenizers import SentencePiece
from huggingface_hub import hf_hub_download
import re

# Two instances of nepalitokenizers' SentencePiece tokenizer (the same default model)
# handle the English and Nepali sides respectively.
tokenizer_en = SentencePiece()
tokenizer_ne = SentencePiece()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Special token ids are placed after the largest base vocabulary so they never
# collide with an existing SentencePiece id.
START_TOKEN = '<START>'
PADDING_TOKEN = '<PADDING>'
END_TOKEN = '<END>'
SPECIAL_TOKENS = {
    START_TOKEN: max(tokenizer_en.get_vocab_size(), tokenizer_ne.get_vocab_size()),
    PADDING_TOKEN: max(tokenizer_en.get_vocab_size(), tokenizer_ne.get_vocab_size()) + 1,
    END_TOKEN: max(tokenizer_en.get_vocab_size(), tokenizer_ne.get_vocab_size()) + 2,
}

# Vocabulary sizes include the three special tokens defined above.
en_vocab_size = tokenizer_en.get_vocab_size() + len(SPECIAL_TOKENS)
ne_vocab_size = tokenizer_ne.get_vocab_size() + len(SPECIAL_TOKENS)

# Token-to-index maps passed to the Transformer, with the special token ids merged in.
english_to_index = {token: i for i, token in enumerate(tokenizer_en.get_vocab())}
nepali_to_index = {token: i for i, token in enumerate(tokenizer_ne.get_vocab())}
english_to_index.update(SPECIAL_TOKENS)
nepali_to_index.update(SPECIAL_TOKENS)

# Model and decoding hyperparameters
max_sequence_length = 100
d_model = 512
batch_size = 32
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
encoder_layers = 6
decoder_layers = 4

transformer = Transformer(
    d_model, ffn_hidden, num_heads, drop_prob, encoder_layers, decoder_layers,
    max_sequence_length, ne_vocab_size, english_to_index, nepali_to_index,
    START_TOKEN, END_TOKEN, PADDING_TOKEN
).to(device)

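# Optional sanity check (not part of the original script): assuming xtransformer's
# Transformer is a regular torch.nn.Module (it is moved with .to(device) above and
# loaded with load_state_dict below), its parameter count can be reported here.
print(f"Transformer parameters: {sum(p.numel() for p in transformer.parameters()):,}")
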
def encode_with_special_tokens(text, tokenizer, max_sequence_length, add_start_end=True):
    """Tokenize text, optionally wrap it in <START>/<END>, then truncate and pad to a
    fixed length. Truncation happens after the end token is added, so inputs longer
    than max_sequence_length lose their <END> marker."""
    tokens = tokenizer.encode(text).ids
    if add_start_end:
        tokens = [SPECIAL_TOKENS[START_TOKEN]] + tokens + [SPECIAL_TOKENS[END_TOKEN]]
    tokens = tokens[:max_sequence_length]
    padding = [SPECIAL_TOKENS[PADDING_TOKEN]] * (max_sequence_length - len(tokens))
    return tokens + padding

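# Illustrative check, not part of the original pipeline: every encoded sentence comes
# back as exactly max_sequence_length ids, starting with the <START> id and padded out
# with the <PADDING> id. The sample sentence is arbitrary.
_example_ids = encode_with_special_tokens("hello world", tokenizer_en, max_sequence_length)
assert len(_example_ids) == max_sequence_length
assert _example_ids[0] == SPECIAL_TOKENS[START_TOKEN]
assert _example_ids[-1] == SPECIAL_TOKENS[PADDING_TOKEN]
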
def decode_with_special_tokens(token_ids, tokenizer):
    """Drop the <START>/<END>/<PADDING> ids before handing the sequence back to the tokenizer."""
    token_ids = [token_id for token_id in token_ids if token_id not in SPECIAL_TOKENS.values()]
    return tokenizer.decode(token_ids)

NEG_INFTY = -1e9

def create_masks(eng_batch, decoder_input):
    """Build additive attention masks: padding positions (and, for decoder
    self-attention, future positions) receive NEG_INFTY, everything else 0."""
    batch_size, enc_seq_length = eng_batch.size(0), eng_batch.size(1)
    dec_seq_length = decoder_input.size(1)
    device = eng_batch.device

    # Boolean masks: True marks a position that must be blocked.
    encoder_padding_mask = (eng_batch == SPECIAL_TOKENS[PADDING_TOKEN]).unsqueeze(1).unsqueeze(2)
    decoder_padding_mask_self = (decoder_input == SPECIAL_TOKENS[PADDING_TOKEN]).unsqueeze(1).unsqueeze(2)
    look_ahead_mask = torch.triu(torch.ones(dec_seq_length, dec_seq_length, device=device), diagonal=1).bool().unsqueeze(0).unsqueeze(0)
    decoder_padding_mask_cross = (eng_batch == SPECIAL_TOKENS[PADDING_TOKEN]).unsqueeze(1).unsqueeze(2)

    # Convert to additive masks that can be summed onto the attention logits.
    encoder_mask = encoder_padding_mask * NEG_INFTY
    decoder_self_mask = (look_ahead_mask | decoder_padding_mask_self) * NEG_INFTY
    decoder_cross_mask = decoder_padding_mask_cross * NEG_INFTY

    return encoder_mask, decoder_self_mask, decoder_cross_mask

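# Illustrative shape check, not part of the original pipeline: for a dummy CPU batch of
# two length-5 sequences, the encoder and cross masks broadcast over heads and query
# positions, while the decoder self-attention mask is square over target positions.
_src = torch.full((2, 5), SPECIAL_TOKENS[PADDING_TOKEN])
_tgt = torch.full((2, 5), SPECIAL_TOKENS[PADDING_TOKEN])
_enc_m, _dec_m, _cross_m = create_masks(_src, _tgt)
assert _enc_m.shape == (2, 1, 1, 5)
assert _dec_m.shape == (2, 1, 5, 5)
assert _cross_m.shape == (2, 1, 1, 5)
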
def translate(sentence):
    """Translate between English and Nepali, choosing the checkpoint by script detection."""
    def is_english(text):
        return re.match(r'^[a-zA-Z\s]+$', text) is not None

    if is_english(sentence):
        clean_sentence = re.sub(r'[^a-zA-Z0-9\s]', '', sentence.strip()).lower()
        checkpoint_file = "checkpoint_en_ne.pth"
        print('using english to nepali transformer')
    else:
        # Keep only Devanagari characters (U+0900 to U+097F) and whitespace.
        clean_sentence = re.sub(r'[^ऀ-ॿ\s]', '', sentence.strip())
        checkpoint_file = "checkpoint_ne_en.pth"
        print('using nepali to english transformer')

    # Fetch the direction-specific weights from the Hugging Face Hub and load them
    # into the shared Transformer instance.
    try:
        checkpoint_path = hf_hub_download(
            repo_id="bishaltwr/xtransformer",
            filename=checkpoint_file,
            repo_type="model"
        )
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        transformer.load_state_dict(checkpoint['model_state'])
        transformer.eval()
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return f"Translation failed: Could not load model checkpoint ({e})"

    with torch.no_grad():
        # Encode the source sentence; the target starts with <START> followed by padding.
        eng_tokens = encode_with_special_tokens(clean_sentence, tokenizer_en, max_sequence_length)
        eng_batch = torch.tensor([eng_tokens]).to(device)
        ne_batch = torch.tensor([[SPECIAL_TOKENS[START_TOKEN]] + [SPECIAL_TOKENS[PADDING_TOKEN]] * (max_sequence_length - 1)]).to(device)

        # Greedy decoding: predict one token per step until <END> or the length limit.
        for i in range(1, max_sequence_length):
            encoder_mask, decoder_mask, cross_mask = create_masks(eng_batch, ne_batch)
            predictions = transformer(eng_batch, ne_batch, encoder_mask, decoder_mask, cross_mask)
            next_token = torch.argmax(predictions[:, i - 1, :], dim=-1)
            if next_token.item() == SPECIAL_TOKENS[END_TOKEN]:
                break
            ne_batch[0, i] = next_token

        return decode_with_special_tokens(ne_batch[0].tolist(), tokenizer_ne)

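# Example usage (illustrative): assumes the bishaltwr/xtransformer checkpoints are
# reachable on the Hugging Face Hub; the sample sentence is arbitrary.
if __name__ == "__main__":
    print(translate("how are you"))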