import re

import torch
from huggingface_hub import hf_hub_download
from nepalitokenizers import SentencePiece
from xtransformer import Transformer

# Initialize tokenizers. Both sides instantiate the same SentencePiece model from
# nepalitokenizers, so English and Nepali text share a single subword vocabulary.
tokenizer_en = SentencePiece()  # applied to English text
tokenizer_ne = SentencePiece()  # applied to Nepali text
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define special tokens and their IDs
START_TOKEN = '<START>'
PADDING_TOKEN = '<PADDING>'
END_TOKEN = '<END>'
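# Assign ids just past the larger of the two vocab sizes so the special tokens
# cannot collide with any real subword id from either tokenizer.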
SPECIAL_TOKENS = {
    START_TOKEN: max(tokenizer_en.get_vocab_size(), tokenizer_ne.get_vocab_size()),
    PADDING_TOKEN: max(tokenizer_en.get_vocab_size(), tokenizer_ne.get_vocab_size()) + 1,
    END_TOKEN: max(tokenizer_en.get_vocab_size(), tokenizer_ne.get_vocab_size()) + 2,
}

# Update vocabulary sizes
en_vocab_size = tokenizer_en.get_vocab_size() + len(SPECIAL_TOKENS)
ne_vocab_size = tokenizer_ne.get_vocab_size() + len(SPECIAL_TOKENS)

# Create token-to-index mappings. get_vocab() already returns a token -> id dict
# (HF tokenizers convention); re-enumerating its keys would assign fresh indices
# that disagree with the ids produced by encode(), so the dict is used directly.
english_to_index = dict(tokenizer_en.get_vocab())
nepali_to_index = dict(tokenizer_ne.get_vocab())
english_to_index.update(SPECIAL_TOKENS)
nepali_to_index.update(SPECIAL_TOKENS)

# Hyperparameters
max_sequence_length = 100
d_model = 512
batch_size = 32
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
encoder_layers = 6
decoder_layers = 4

# Initialize the Transformer model
transformer = Transformer(
    d_model, ffn_hidden, num_heads, drop_prob, encoder_layers, decoder_layers,
    max_sequence_length, ne_vocab_size, english_to_index, nepali_to_index,
    START_TOKEN, END_TOKEN, PADDING_TOKEN
).to(device)
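# The same Transformer instance serves both directions; translate() below loads
# the matching checkpoint (en->ne or ne->en) into it before decoding.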

# Function to encode text with special tokens
def encode_with_special_tokens(text, tokenizer, max_sequence_length, add_start_end=True):
    tokens = tokenizer.encode(text).ids
    if add_start_end:
        tokens = [SPECIAL_TOKENS[START_TOKEN]] + tokens + [SPECIAL_TOKENS[END_TOKEN]]
    if len(tokens) > max_sequence_length:
        tokens = tokens[:max_sequence_length]
        if add_start_end:
            tokens[-1] = SPECIAL_TOKENS[END_TOKEN]  # keep END even after truncation
    padding = [SPECIAL_TOKENS[PADDING_TOKEN]] * (max_sequence_length - len(tokens))
    return tokens + padding
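
# Illustrative behavior (subword ids are hypothetical; real values depend on the
# trained SentencePiece vocab):
#   encode_with_special_tokens("hello", tokenizer_en, 6)
#   -> [START_id, subword_id(s), END_id, PAD_id, ...] with total length 6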

# Function to decode token IDs, filtering out special tokens
def decode_with_special_tokens(token_ids, tokenizer):
    token_ids = [token_id for token_id in token_ids if token_id not in SPECIAL_TOKENS.values()]
    return tokenizer.decode(token_ids)
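
# Stripping the special ids first matters: they lie outside the tokenizer's own
# vocabulary, so passing them to decode() could otherwise raise or emit garbage.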

# Mask creation
NEG_INFTY = -1e9
def create_masks(encoder_input, decoder_input):
    dec_seq_length = decoder_input.size(1)
    device = encoder_input.device

    # Padding positions in the source are blocked in encoder self-attention and in
    # cross-attention; the look-ahead mask keeps decoder position i from attending
    # to positions > i.
    encoder_padding_mask = (encoder_input == SPECIAL_TOKENS[PADDING_TOKEN]).unsqueeze(1).unsqueeze(2)
    decoder_padding_mask_self = (decoder_input == SPECIAL_TOKENS[PADDING_TOKEN]).unsqueeze(1).unsqueeze(2)
    look_ahead_mask = torch.triu(torch.ones(dec_seq_length, dec_seq_length, device=device), diagonal=1).bool().unsqueeze(0).unsqueeze(0)
    decoder_padding_mask_cross = (encoder_input == SPECIAL_TOKENS[PADDING_TOKEN]).unsqueeze(1).unsqueeze(2)

    # Additive masks: 0 where attention is allowed, NEG_INFTY where it is blocked.
    # Shapes (B = batch): (B, 1, 1, enc_len), (B, 1, dec_len, dec_len), (B, 1, 1, enc_len).
    encoder_mask = encoder_padding_mask * NEG_INFTY
    decoder_self_mask = (look_ahead_mask | decoder_padding_mask_self) * NEG_INFTY
    decoder_cross_mask = decoder_padding_mask_cross * NEG_INFTY

    return encoder_mask, decoder_self_mask, decoder_cross_mask

# Translation function (greedy decoding; the checkpoint is reloaded on every
# call, which is simple but slow for repeated use)
def translate(sentence):
    def is_english(text):
        # True if the text contains only English letters and whitespace
        return re.match(r'^[a-zA-Z\s]+$', text) is not None

    # Pick the checkpoint and tokenizers to match the input language
    if is_english(sentence):
        clean_sentence = re.sub(r'[^a-zA-Z0-9\s]', '', sentence.strip()).lower()
        checkpoint_file = "checkpoint_en_ne.pth"
        src_tokenizer, tgt_tokenizer = tokenizer_en, tokenizer_ne
        print('Using English-to-Nepali transformer')
    else:
        clean_sentence = re.sub(r'[^ऀ-ॿ\s]', '', sentence.strip())
        checkpoint_file = "checkpoint_ne_en.pth"
        src_tokenizer, tgt_tokenizer = tokenizer_ne, tokenizer_en
        print('Using Nepali-to-English transformer')
    
    # Download the checkpoint from the Hugging Face Hub (cached locally after the first call)
    try:
        checkpoint_path = hf_hub_download(
            repo_id="bishaltwr/xtransformer", 
            filename=checkpoint_file,
            repo_type="model"
        )
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        transformer.load_state_dict(checkpoint['model_state'])
        transformer.eval()
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return f"Translation failed: Could not load model checkpoint ({str(e)})"

    with torch.no_grad():
        src_tokens = encode_with_special_tokens(clean_sentence, src_tokenizer, max_sequence_length)
        src_batch = torch.tensor([src_tokens]).to(device)
        # Decoder input starts as <START> followed by padding; positions fill greedily
        tgt_batch = torch.tensor(
            [[SPECIAL_TOKENS[START_TOKEN]] + [SPECIAL_TOKENS[PADDING_TOKEN]] * (max_sequence_length - 1)]
        ).to(device)

        for i in range(1, max_sequence_length):
            encoder_mask, decoder_mask, cross_mask = create_masks(src_batch, tgt_batch)
            predictions = transformer(src_batch, tgt_batch, encoder_mask, decoder_mask, cross_mask)
            # The logit at position i-1 predicts the token at position i
            next_token = torch.argmax(predictions[:, i - 1, :], dim=-1)
            if next_token.item() == SPECIAL_TOKENS[END_TOKEN]:
                break
            tgt_batch[0, i] = next_token

    return decode_with_special_tokens(tgt_batch[0].tolist(), tgt_tokenizer)
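
if __name__ == "__main__":
    # Minimal smoke test (a sketch: assumes the checkpoints in the
    # bishaltwr/xtransformer Hub repo are reachable and match this architecture).
    print(translate("hello how are you"))
    print(translate("नमस्ते संसार"))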