# Ne-En-Trn / inference.py
from xtransformer import Transformer
import torch
import torch.nn as nn
from nepalitokenizers import SentencePiece
from huggingface_hub import hf_hub_download
import re
# Initialize tokenizers
tokenizer_en = SentencePiece()  # tokenizer for the English side (nepalitokenizers SentencePiece model)
tokenizer_ne = SentencePiece()  # separate instance of the same tokenizer for the Nepali side
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Define special tokens and their IDs
START_TOKEN = '<START>'
PADDING_TOKEN = '<PADDING>'
END_TOKEN = '<END>'
SPECIAL_TOKENS = {
    START_TOKEN: max(tokenizer_en.get_vocab_size(), tokenizer_ne.get_vocab_size()),
    PADDING_TOKEN: max(tokenizer_en.get_vocab_size(), tokenizer_ne.get_vocab_size()) + 1,
    END_TOKEN: max(tokenizer_en.get_vocab_size(), tokenizer_ne.get_vocab_size()) + 2,
}
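# The special tokens get ids just past the larger of the two base vocabularies,
# so they can never collide with an id produced by either tokenizer.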
# Vocabulary sizes including the three special tokens
en_vocab_size = tokenizer_en.get_vocab_size() + len(SPECIAL_TOKENS)
ne_vocab_size = tokenizer_ne.get_vocab_size() + len(SPECIAL_TOKENS)
# Create token-to-index mappings
english_to_index = {token: i for i, token in enumerate(tokenizer_en.get_vocab())}
nepali_to_index = {token: i for i, token in enumerate(tokenizer_ne.get_vocab())}
english_to_index.update(SPECIAL_TOKENS)
nepali_to_index.update(SPECIAL_TOKENS)
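# These index maps (base vocabulary plus the three special tokens) are what the xtransformer
# Transformer below is constructed with; the actual token ids used at inference come from
# tokenizer.encode(...).ids in encode_with_special_tokens.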
# Hyperparameters
max_sequence_length = 100
d_model = 512
batch_size = 32
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
encoder_layers = 6
decoder_layers = 4
# Initialize the Transformer model
transformer = Transformer(
    d_model, ffn_hidden, num_heads, drop_prob, encoder_layers, decoder_layers,
    max_sequence_length, ne_vocab_size, english_to_index, nepali_to_index,
    START_TOKEN, END_TOKEN, PADDING_TOKEN
).to(device)
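# Since both tokenizers here are constructed identically, en_vocab_size equals ne_vocab_size,
# so this single model instance (sized with ne_vocab_size) can serve both translation
# directions; translate() below simply loads whichever checkpoint matches the input language.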
# Function to encode text with special tokens
def encode_with_special_tokens(text, tokenizer, max_sequence_length, add_start_end=True):
    tokens = tokenizer.encode(text).ids
    if add_start_end:
        tokens = [SPECIAL_TOKENS[START_TOKEN]] + tokens + [SPECIAL_TOKENS[END_TOKEN]]
    if len(tokens) > max_sequence_length:
        # Truncate over-long inputs, keeping END as the final token when it was added
        tokens = tokens[:max_sequence_length]
        if add_start_end:
            tokens[-1] = SPECIAL_TOKENS[END_TOKEN]
    padding = [SPECIAL_TOKENS[PADDING_TOKEN]] * (max_sequence_length - len(tokens))
    return tokens + padding
# Function to decode token IDs, filtering out special tokens
def decode_with_special_tokens(token_ids, tokenizer):
    token_ids = [token_id for token_id in token_ids if token_id not in SPECIAL_TOKENS.values()]
    return tokenizer.decode(token_ids)
# Mask creation
NEG_INFTY = -1e9
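# The masks below are built as additive attention masks (the usual convention in from-scratch
# Transformer implementations): allowed positions contribute 0 and masked positions contribute
# NEG_INFTY, which is added to the attention scores before softmax so masked positions end up
# with ~zero attention weight.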
def create_masks(eng_batch, decoder_input):
    batch_size, enc_seq_length = eng_batch.size(0), eng_batch.size(1)
    dec_seq_length = decoder_input.size(1)
    device = eng_batch.device
    # Padding masks: True wherever the position holds the PADDING token
    encoder_padding_mask = (eng_batch == SPECIAL_TOKENS[PADDING_TOKEN]).unsqueeze(1).unsqueeze(2)
    decoder_padding_mask_self = (decoder_input == SPECIAL_TOKENS[PADDING_TOKEN]).unsqueeze(1).unsqueeze(2)
    # Causal (look-ahead) mask: True above the diagonal so each position only attends to earlier ones
    look_ahead_mask = torch.triu(torch.ones(dec_seq_length, dec_seq_length, device=device), diagonal=1).bool().unsqueeze(0).unsqueeze(0)
    decoder_padding_mask_cross = (eng_batch == SPECIAL_TOKENS[PADDING_TOKEN]).unsqueeze(1).unsqueeze(2)
    encoder_mask = encoder_padding_mask * NEG_INFTY
    decoder_self_mask = (look_ahead_mask | decoder_padding_mask_self) * NEG_INFTY
    decoder_cross_mask = decoder_padding_mask_cross * NEG_INFTY
    return encoder_mask, decoder_self_mask, decoder_cross_mask
# Translation function
def translate(sentence):
    def is_english(text):
        # Check if the text contains only English letters and spaces using a regular expression
        return re.match(r'^[a-zA-Z\s]+$', text) is not None
    # Pick the checkpoint to load based on the input language
    if is_english(sentence):
        clean_sentence = re.sub(r'[^a-zA-Z0-9\s]', '', sentence.strip()).lower()
        checkpoint_file = "checkpoint_en_ne.pth"
        print('using english to nepali transformer')
    else:
        clean_sentence = re.sub(r'[^ऀ-ॿ\s]', '', sentence.strip())
        checkpoint_file = "checkpoint_ne_en.pth"
        print('using nepali to english transformer')
    # Download the checkpoint from the Hugging Face Hub (hf_hub_download caches it locally after the first call)
    try:
        checkpoint_path = hf_hub_download(
            repo_id="bishaltwr/xtransformer",
            filename=checkpoint_file,
            repo_type="model"
        )
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        transformer.load_state_dict(checkpoint['model_state'])
        transformer.eval()
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return f"Translation failed: Could not load model checkpoint ({str(e)})"
    with torch.no_grad():
        # Both tokenizers are identical, so the source text is always encoded with tokenizer_en
        # and the output decoded with tokenizer_ne, regardless of translation direction.
        eng_tokens = encode_with_special_tokens(clean_sentence, tokenizer_en, max_sequence_length)
        eng_batch = torch.tensor([eng_tokens]).to(device)
        # Decoder input starts with START followed by PADDING placeholders
        ne_batch = torch.tensor([[SPECIAL_TOKENS[START_TOKEN]] + [SPECIAL_TOKENS[PADDING_TOKEN]] * (max_sequence_length - 1)]).to(device)
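        # Greedy autoregressive decoding: each step feeds the current (padded) target prefix
        # through the model, takes the argmax at output position i-1, writes it into slot i,
        # and stops early once the END token is predicted.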
        for i in range(1, max_sequence_length):
            encoder_mask, decoder_mask, cross_mask = create_masks(eng_batch, ne_batch)
            predictions = transformer(eng_batch, ne_batch, encoder_mask, decoder_mask, cross_mask)
            next_token = torch.argmax(predictions[:, i - 1, :], dim=-1)
            if next_token.item() == SPECIAL_TOKENS[END_TOKEN]:
                break
            ne_batch[0, i] = next_token
    return decode_with_special_tokens(ne_batch[0].tolist(), tokenizer_ne)
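# Minimal usage sketch (assumes the bishaltwr/xtransformer checkpoints are reachable from this
# environment); the first call per direction is slow because the checkpoint is downloaded then.
if __name__ == "__main__":
    print(translate("hello how are you"))   # English -> Nepali
    print(translate("तपाईंलाई कस्तो छ"))        # Nepali -> English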