bishaltwr committed
Commit 5b4b058
1 Parent(s): b062ffb
Files changed (4)
  1. app.py +214 -0
  2. inference.py +124 -0
  3. requirements.txt +9 -0
  4. xtransformer.py +338 -0
app.py ADDED
@@ -0,0 +1,214 @@
+ import gradio as gr
+ import torch
+ import os
+ import io
+ from gtts import gTTS
+ import soundfile as sf
+ import tempfile
+ import logging
+
+ # Import your existing functionality
+ from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
+ from transformers import Wav2Vec2ForCTC, AutoProcessor
+
+ logging.basicConfig(
+     level=logging.DEBUG,
+     format='%(asctime)s - %(levelname)s - %(message)s'
+ )
+
+ # Initialize translation model
+ checkpoint_dir = "bishaltwr/final_m2m100"  # Change to Hugging Face model ID when deployed
+ try:
+     tokenizer = M2M100Tokenizer.from_pretrained(checkpoint_dir)
+     model_m2m = M2M100ForConditionalGeneration.from_pretrained(checkpoint_dir)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model_m2m.to(device)
+     m2m_available = True
+ except Exception as e:
+     logging.error(f"Error loading M2M100 model: {e}")
+     m2m_available = False
+
+ # Initialize ASR model
+ model_id = "bishaltwr/wav2vec2-large-mms-1b-nepali"
+ try:
+     processor = AutoProcessor.from_pretrained(model_id)
+     model_asr = Wav2Vec2ForCTC.from_pretrained(model_id, ignore_mismatched_sizes=True)
+     asr_available = True
+ except Exception as e:
+     logging.error(f"Error loading ASR model: {e}")
+     asr_available = False
+
+ # Initialize X-Transformer model
+ try:
+     from inference import translate as xtranslate
+     xtransformer_available = True
+ except Exception as e:
+     logging.error(f"Error loading XTransformer model: {e}")
+     xtransformer_available = False
+
+ def m2m_translate(text, source_lang, target_lang):
+     """Translation using M2M100 model"""
+     if not m2m_available:
+         return "M2M100 model not available"
+
+     tokenizer.src_lang = source_lang
+     inputs = tokenizer(text, return_tensors="pt").to(device)
+     translated_tokens = model_m2m.generate(
+         **inputs,
+         forced_bos_token_id=tokenizer.get_lang_id(target_lang)
+     )
+     translated_text = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
+     return translated_text
+
+ def transcribe_audio(audio_path, language="npi"):
+     """Transcribe audio using ASR model"""
+     if not asr_available:
+         return "ASR model not available"
+
+     import librosa
+     audio, sr = librosa.load(audio_path, sr=16000)
+     processor.tokenizer.set_target_lang(language)
+     model_asr.load_adapter(language)
+     inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
+
+     with torch.no_grad():
+         outputs = model_asr(**inputs).logits
+
+     ids = torch.argmax(outputs, dim=-1)[0]
+     transcription = processor.decode(ids, skip_special_tokens=True)
+
+     if language == "eng":
+         transcription = transcription.replace('<pad>', '').replace('<unk>', '')
+     else:
+         transcription = transcription.replace('<pad>', ' ').replace('<unk>', '')
+
+     return transcription
+
+ def text_to_speech(text):
+     """Convert text to speech using gTTS"""
+     if not text:
+         return None
+
+     try:
+         # gTTS writes MP3 data, so give the temporary file a matching suffix
+         with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_audio:
+             tts = gTTS(text=text)
+             tts.save(temp_audio.name)
+             return temp_audio.name
+     except Exception as e:
+         logging.error(f"TTS error: {e}")
+         return None
+
+ def detect_language(text):
+     """Simple language detection function"""
+     english_chars = sum(1 for c in text if c.isascii() and c.isalpha())
+     return "en" if english_chars > len(text) * 0.5 else "ne"
+
+ def translate_text(text, model_choice, source_lang=None, target_lang=None):
+     """Main translation function"""
+     if not text:
+         return "Please enter some text to translate"
+
+     # Auto-detect languages if not specified
+     if not source_lang:
+         source_lang = detect_language(text)
+     if not target_lang:
+         target_lang = "ne" if source_lang == "en" else "en"
+
+     # Choose the translation model
+     if model_choice == "XTransformer" and xtransformer_available:
+         return xtranslate(text)
+     elif model_choice == "M2M100" and m2m_available:
+         return m2m_translate(text, source_lang=source_lang, target_lang=target_lang)
+     else:
+         return "Selected model is not available"
+
+ # Set up the Gradio interface
+ with gr.Blocks(title="Nepali-English Translator") as demo:
+     gr.Markdown("# Nepali-English Translation Service")
+     gr.Markdown("Translate between Nepali and English, transcribe audio, and convert text to speech.")
+
+     # Set up tabs for different functions
+     with gr.Tabs():
+         # Text Translation Tab
+         with gr.TabItem("Text Translation"):
+             with gr.Row():
+                 with gr.Column():
+                     text_input = gr.Textbox(label="Input Text", lines=5)
+
+                     with gr.Row():
+                         model_choice = gr.Radio(
+                             choices=["XTransformer", "M2M100"],
+                             value="XTransformer",
+                             label="Translation Model"
+                         )
+
+                     with gr.Row():
+                         source_lang = gr.Dropdown(
+                             choices=["Auto-detect", "en", "ne"],
+                             value="Auto-detect",
+                             label="Source Language",
+                             visible=True
+                         )
+                         target_lang = gr.Dropdown(
+                             choices=["Auto-select", "en", "ne"],
+                             value="Auto-select",
+                             label="Target Language",
+                             visible=True
+                         )
+
+                     translate_button = gr.Button("Translate")
+
+                 with gr.Column():
+                     translation_output = gr.Textbox(label="Translation Output", lines=5)
+                     tts_button = gr.Button("Convert to Speech")
+                     audio_output = gr.Audio(label="Audio Output")
+
+         # Speech to Text Tab
+         with gr.TabItem("Speech to Text"):
+             with gr.Column():
+                 audio_input = gr.Audio(label="Upload or Record Audio", type="filepath")
+                 asr_language = gr.Radio(
+                     choices=["eng", "npi"],
+                     value="npi",
+                     label="Speech Language"
+                 )
+                 transcribe_button = gr.Button("Transcribe")
+                 transcription_output = gr.Textbox(label="Transcription Output", lines=3)
+
+     # Define event handlers
+     def process_translation(text, model, src_lang, tgt_lang):
+         if src_lang == "Auto-detect":
+             src_lang = None
+         if tgt_lang == "Auto-select":
+             tgt_lang = None
+         return translate_text(text, model, src_lang, tgt_lang)
+
+     def process_tts(text):
+         return text_to_speech(text)
+
+     def process_transcription(audio_path, language):
+         if not audio_path:
+             return "Please upload or record audio"
+         return transcribe_audio(audio_path, language)
+
+     # Connect the components
+     translate_button.click(
+         process_translation,
+         inputs=[text_input, model_choice, source_lang, target_lang],
+         outputs=translation_output
+     )
+
+     tts_button.click(
+         process_tts,
+         inputs=translation_output,
+         outputs=audio_output
+     )
+
+     transcribe_button.click(
+         process_transcription,
+         inputs=[audio_input, asr_language],
+         outputs=transcription_output
+     )
+
+ # Launch the app
+ if __name__ == "__main__":
+     demo.launch()
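
For a quick sanity check, here is a minimal sketch of exercising the same helpers without launching the Gradio UI. It assumes the model downloads above succeed; the function names are the ones defined in app.py, and the expected outputs in the comments are illustrative.

# Smoke test of the app.py helpers (no UI). Importing app builds the Blocks
# layout but does not launch it, since demo.launch() is guarded by __main__.
from app import detect_language, translate_text, text_to_speech

sample = "Kathmandu is the capital city of Nepal"
print(detect_language(sample))            # "en" (mostly ASCII letters)
print(translate_text(sample, "M2M100"))   # auto-detected en -> ne translation
print(text_to_speech("Hello"))            # path to a temporary MP3 file, or None on failure
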
inference.py ADDED
@@ -0,0 +1,124 @@
+ from xtransformer import Transformer
+ import torch
+ import torch.nn as nn
+ from nepalitokenizers import SentencePiece
+ from huggingface_hub import hf_hub_download
+ import re
+
+ # Initialize tokenizers
+ tokenizer_en = SentencePiece()  # English tokenizer
+ tokenizer_ne = SentencePiece()  # Nepali tokenizer
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ # Define special tokens and their IDs
+ START_TOKEN = '<START>'
+ PADDING_TOKEN = '<PADDING>'
+ END_TOKEN = '<END>'
+ SPECIAL_TOKENS = {
+     START_TOKEN: max(tokenizer_en.get_vocab_size(), tokenizer_ne.get_vocab_size()),
+     PADDING_TOKEN: max(tokenizer_en.get_vocab_size(), tokenizer_ne.get_vocab_size()) + 1,
+     END_TOKEN: max(tokenizer_en.get_vocab_size(), tokenizer_ne.get_vocab_size()) + 2,
+ }
+
+ # Update vocabulary sizes
+ en_vocab_size = tokenizer_en.get_vocab_size() + len(SPECIAL_TOKENS)
+ ne_vocab_size = tokenizer_ne.get_vocab_size() + len(SPECIAL_TOKENS)
+
+ # Create token-to-index mappings
+ english_to_index = {token: i for i, token in enumerate(tokenizer_en.get_vocab())}
+ nepali_to_index = {token: i for i, token in enumerate(tokenizer_ne.get_vocab())}
+ english_to_index.update(SPECIAL_TOKENS)
+ nepali_to_index.update(SPECIAL_TOKENS)
+
+ # Hyperparameters
+ max_sequence_length = 100
+ d_model = 512
+ batch_size = 32
+ ffn_hidden = 2048
+ num_heads = 8
+ drop_prob = 0.1
+ encoder_layers = 6
+ decoder_layers = 4
+
+ # Initialize the Transformer model
+ transformer = Transformer(
+     d_model, ffn_hidden, num_heads, drop_prob, encoder_layers, decoder_layers,
+     max_sequence_length, ne_vocab_size, english_to_index, nepali_to_index,
+     START_TOKEN, END_TOKEN, PADDING_TOKEN
+ ).to(device)
+
+ # Function to encode text with special tokens
+ def encode_with_special_tokens(text, tokenizer, max_sequence_length, add_start_end=True):
+     tokens = tokenizer.encode(text).ids
+     if add_start_end:
+         tokens = [SPECIAL_TOKENS[START_TOKEN]] + tokens + [SPECIAL_TOKENS[END_TOKEN]]
+     tokens = tokens[:max_sequence_length]
+     padding = [SPECIAL_TOKENS[PADDING_TOKEN]] * (max_sequence_length - len(tokens))
+     return tokens + padding
+
+ # Function to decode token IDs, filtering out special tokens
+ def decode_with_special_tokens(token_ids, tokenizer):
+     token_ids = [token_id for token_id in token_ids if token_id not in SPECIAL_TOKENS.values()]
+     return tokenizer.decode(token_ids)
+
+ # Mask creation
+ NEG_INFTY = -1e9
+ def create_masks(eng_batch, decoder_input):
+     batch_size, enc_seq_length = eng_batch.size(0), eng_batch.size(1)
+     dec_seq_length = decoder_input.size(1)
+     device = eng_batch.device
+
+     encoder_padding_mask = (eng_batch == SPECIAL_TOKENS[PADDING_TOKEN]).unsqueeze(1).unsqueeze(2)
+     decoder_padding_mask_self = (decoder_input == SPECIAL_TOKENS[PADDING_TOKEN]).unsqueeze(1).unsqueeze(2)
+     look_ahead_mask = torch.triu(torch.ones(dec_seq_length, dec_seq_length, device=device), diagonal=1).bool().unsqueeze(0).unsqueeze(0)
+     decoder_padding_mask_cross = (eng_batch == SPECIAL_TOKENS[PADDING_TOKEN]).unsqueeze(1).unsqueeze(2)
+
+     encoder_mask = encoder_padding_mask * NEG_INFTY
+     decoder_self_mask = (look_ahead_mask | decoder_padding_mask_self) * NEG_INFTY
+     decoder_cross_mask = decoder_padding_mask_cross * NEG_INFTY
+
+     return encoder_mask, decoder_self_mask, decoder_cross_mask
+
+ # Translation function
+ def translate(sentence):
+     def is_english(text):
+         # Check if the text contains only English letters and spaces using a regular expression
+         return re.match(r'^[a-zA-Z\s]+$', text) is not None
+
+     # Determine which checkpoint to use based on the input language
+     if is_english(sentence):
+         clean_sentence = re.sub(r'[^a-zA-Z0-9\s]', '', sentence.strip()).lower()
+         checkpoint_file = "checkpoint_en_ne.pth"
+         print('using english to nepali transformer')
+     else:
+         clean_sentence = re.sub(r'[^ऀ-ॿ\s]', '', sentence.strip())
+         checkpoint_file = "checkpoint_ne_en.pth"
+         print('using nepali to english transformer')
+
+     # Download the checkpoint from Hugging Face Hub
+     try:
+         checkpoint_path = hf_hub_download(
+             repo_id="bishaltwr/xtransformer",
+             filename=checkpoint_file,
+             repo_type="model"
+         )
+         checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
+         transformer.load_state_dict(checkpoint['model_state'])
+         transformer.eval()
+     except Exception as e:
+         print(f"Error loading checkpoint: {e}")
+         return f"Translation failed: Could not load model checkpoint ({str(e)})"
+
+     with torch.no_grad():
+         eng_tokens = encode_with_special_tokens(clean_sentence, tokenizer_en, max_sequence_length)
+         eng_batch = torch.tensor([eng_tokens]).to(device)
+         ne_batch = torch.tensor([[SPECIAL_TOKENS[START_TOKEN]] + [SPECIAL_TOKENS[PADDING_TOKEN]] * (max_sequence_length - 1)]).to(device)
+
+         for i in range(1, max_sequence_length):
+             encoder_mask, decoder_mask, cross_mask = create_masks(eng_batch, ne_batch)
+             predictions = transformer(eng_batch, ne_batch, encoder_mask, decoder_mask, cross_mask)
+             next_token = torch.argmax(predictions[:, i - 1, :], dim=-1)
+             if next_token.item() == SPECIAL_TOKENS[END_TOKEN]:
+                 break
+             ne_batch[0, i] = next_token
+
+     return decode_with_special_tokens(ne_batch[0].tolist(), tokenizer_ne)
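
A short usage sketch for inference.py follows; it assumes network access, since the first call downloads checkpoint_en_ne.pth or checkpoint_ne_en.pth from bishaltwr/xtransformer on the Hub. The sample outputs are illustrative, not quoted results.

# translate() picks the direction from the input script: Latin letters go
# through the en->ne checkpoint, anything else through ne->en.
from inference import translate

print(translate("how are you"))       # English input, Nepali output
print(translate("तपाईंलाई कस्तो छ"))    # Devanagari input, English output
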
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ gradio>=5.20.1
+ torch
+ transformers
+ librosa
+ soundfile
+ gtts
+ nepalitokenizers
+ sounddevice
+ huggingface_hub
xtransformer.py ADDED
@@ -0,0 +1,338 @@
+ # import numpy as np  # Unused import
+ import torch
+ import math
+ from torch import nn
+ import torch.nn.functional as F
+ from nepalitokenizers import SentencePiece
+ from torch.amp import autocast  # Mixed precision
+ from torch.utils.checkpoint import checkpoint  # Gradient checkpointing
+
+ # Device setup
+ def get_device():
+     return torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+ # Efficient Scaled Dot-Product Attention
+ def scaled_dot_product(q, k, v, mask=None):
+     d_k = q.size()[-1]
+     scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # Simplified attention computation
+     if mask is not None:
+         scores += mask
+     attention = F.softmax(scores, dim=-1)
+     values = torch.matmul(attention, v)
+     return values, attention
+
+ # Precompute Positional Encoding
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model, max_sequence_length):
+         super().__init__()
+         self.max_sequence_length = max_sequence_length
+         self.d_model = d_model
+         self.pe = self._create_positional_encoding()  # Precompute during initialization
+
+     def _create_positional_encoding(self):
+         position = torch.arange(self.max_sequence_length).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * (-math.log(10000.0) / self.d_model))
+         pe = torch.zeros(self.max_sequence_length, self.d_model)
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         return pe
+
+     def forward(self, x):
+         seq_length = x.size(1)  # Handle variable sequence lengths
+         return self.pe[:seq_length, :].to(x.device)
+
+ # Efficient Sentence Embedding with Caching
+ class SentenceEmbedding(nn.Module):
+     def __init__(self, max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN):
+         super().__init__()
+         self.vocab_size = len(language_to_index)
+         self.max_sequence_length = max_sequence_length
+         self.embedding = nn.Embedding(self.vocab_size, d_model)
+         self.language_to_index = language_to_index
+         self.position_encoder = PositionalEncoding(d_model, max_sequence_length)
+         self.dropout = nn.Dropout(p=0.1)
+         self.START_TOKEN = START_TOKEN
+         self.END_TOKEN = END_TOKEN
+         self.PADDING_TOKEN = PADDING_TOKEN
+         self.tokenizer = SentencePiece()
+
+     def batch_tokenize(self, batch, start_token, end_token):
+         """
+         Tokenizes a batch of sentences or processes pre-tokenized tensors.
+
+         Args:
+             batch: A list of sentences (str) or a tensor of token IDs.
+             start_token: Whether to add a start token.
+             end_token: Whether to add an end token.
+
+         Returns:
+             A tensor of token IDs with shape (batch_size, seq_len).
+         """
+         # If input is already a tensor, return it directly
+         if isinstance(batch, torch.Tensor):
+             return batch.to(get_device())
+
+         # Process raw text inputs
+         token_ids = []
+         for sentence in batch:
+             if not isinstance(sentence, str):
+                 sentence = str(sentence).strip()
+             if not sentence:
+                 sentence = self.PADDING_TOKEN
+             try:
+                 tokens = self.tokenizer.encode(sentence)
+                 token_ids.append(tokens.ids)
+             except Exception:
+                 print(f"Error tokenizing: {sentence}")
+                 token_ids.append([self.language_to_index.get(self.PADDING_TOKEN, 0)])
+
+         # Add start and end tokens if required
+         if start_token:
+             token_ids = [[self.language_to_index.get(self.START_TOKEN, self.PADDING_TOKEN)] + ids for ids in token_ids]
+         if end_token:
+             token_ids = [ids + [self.language_to_index.get(self.END_TOKEN, self.PADDING_TOKEN)] for ids in token_ids]
+
+         # Truncate sequences to max_sequence_length
+         token_ids = [ids[:self.max_sequence_length] for ids in token_ids]
+
+         # Pad sequences to max_sequence_length
+         token_ids = torch.nn.utils.rnn.pad_sequence(
+             [torch.tensor(ids, dtype=torch.long) for ids in token_ids],
+             batch_first=True,
+             padding_value=self.language_to_index.get(self.PADDING_TOKEN, 0)
+         ).to(get_device())
+
+         return token_ids
+
+     def forward(self, x, start_token, end_token):
+         """
+         Forward pass for the SentenceEmbedding module.
+
+         Args:
+             x: Input batch (list of sentences or tensor of token IDs).
+             start_token: Whether to add a start token.
+             end_token: Whether to add an end token.
+
+         Returns:
+             Embedded and positional-encoded output tensor.
+         """
+         # Tokenize input if it's raw text
+         if not isinstance(x, torch.Tensor):
+             x = self.batch_tokenize(x, start_token, end_token)
+
+         # Embed tokens and add positional encoding
+         x = self.embedding(x)
+         pos = self.position_encoder(x)
+         x = self.dropout(x + pos)
+         return x
+
+ # Multi-Head Attention with Efficient Matrix Operations
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, d_model, num_heads):
+         super().__init__()
+         self.d_model = d_model
+         self.num_heads = num_heads
+         self.head_dim = d_model // num_heads
+         self.qkv_layer = nn.Linear(d_model, 3 * d_model)
+         self.linear_layer = nn.Linear(d_model, d_model)
+
+     def forward(self, x, mask):
+         batch_size, seq_length, d_model = x.size()
+         qkv = self.qkv_layer(x)
+         qkv = qkv.view(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
+         qkv = qkv.permute(0, 2, 1, 3)  # (batch_size, num_heads, seq_length, 3 * head_dim)
+         q, k, v = qkv.chunk(3, dim=-1)
+         values, _ = scaled_dot_product(q, k, v, mask)  # Ignore unused variable 'attention'
+         values = values.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, d_model)
+         out = self.linear_layer(values)
+         return out
+
+ # Multi-Head Cross Attention
+ class MultiHeadCrossAttention(nn.Module):
+     def __init__(self, d_model, num_heads):
+         super().__init__()
+         self.d_model = d_model
+         self.num_heads = num_heads
+         self.head_dim = d_model // num_heads
+         self.kv_layer = nn.Linear(d_model, 2 * d_model)
+         self.q_layer = nn.Linear(d_model, d_model)
+         self.linear_layer = nn.Linear(d_model, d_model)
+
+     def forward(self, x, y, mask):
+         batch_size, x_seq_length, _ = x.size()  # Encoder sequence length
+         batch_size, y_seq_length, _ = y.size()  # Decoder sequence length
+
+         # Process encoder output (x) for Key/Value
+         kv = self.kv_layer(x)
+         kv = kv.view(batch_size, x_seq_length, self.num_heads, 2 * self.head_dim)
+         kv = kv.permute(0, 2, 1, 3)  # [batch, heads, x_seq, 2*head_dim]
+         k, v = kv.chunk(2, dim=-1)  # Each [batch, heads, x_seq, head_dim]
+
+         # Process decoder input (y) for Query
+         q = self.q_layer(y)
+         q = q.view(batch_size, y_seq_length, self.num_heads, self.head_dim)
+         q = q.permute(0, 2, 1, 3)  # [batch, heads, y_seq, head_dim]
+
+         # Compute attention
+         values, _ = scaled_dot_product(q, k, v, mask)
+
+         # Reshape back to original dimensions
+         values = values.permute(0, 2, 1, 3).contiguous()
+         values = values.view(batch_size, y_seq_length, self.d_model)
+         return self.linear_layer(values)
+
+ # Layer Normalization
+ class LayerNormalization(nn.Module):
+     def __init__(self, parameters_shape, eps=1e-5):
+         super().__init__()
+         self.layer_norm = nn.LayerNorm(parameters_shape, eps=eps)
+
+     def forward(self, inputs):
+         return self.layer_norm(inputs)
+
+ # Position-wise Feed-Forward Network
+ class PositionwiseFeedForward(nn.Module):
+     def __init__(self, d_model, hidden, drop_prob=0.1):
+         super().__init__()
+         self.linear1 = nn.Linear(d_model, hidden)
+         self.linear2 = nn.Linear(hidden, d_model)
+         self.relu = nn.ReLU()
+         self.dropout = nn.Dropout(p=drop_prob)
+
+     def forward(self, x):
+         x = self.linear1(x)
+         x = self.relu(x)
+         x = self.dropout(x)
+         x = self.linear2(x)
+         return x
+
+ # Encoder Layer with Gradient Checkpointing
+ class EncoderLayer(nn.Module):
+     def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
+         super().__init__()
+         self.attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
+         self.norm1 = LayerNormalization(parameters_shape=[d_model])
+         self.dropout1 = nn.Dropout(p=drop_prob)
+         self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
+         self.norm2 = LayerNormalization(parameters_shape=[d_model])
+         self.dropout2 = nn.Dropout(p=drop_prob)
+
+     def forward(self, x, self_attention_mask):
+         residual_x = x.clone()
+         x = checkpoint(self.attention, x, self_attention_mask, preserve_rng_state=True, use_reentrant=False)  # Gradient checkpointing
+         x = self.dropout1(x)
+         x = self.norm1(x + residual_x)
+         residual_x = x.clone()
+         x = checkpoint(self.ffn, x, preserve_rng_state=True, use_reentrant=False)  # Gradient checkpointing
+         x = self.dropout2(x)
+         x = self.norm2(x + residual_x)
+         return x
+
+ # Sequential Encoder
+ class SequentialEncoder(nn.Sequential):
+     def forward(self, *inputs):
+         x, self_attention_mask = inputs
+         for module in self._modules.values():
+             x = module(x, self_attention_mask)
+         return x
+
+ # Encoder with Mixed Precision
+ class Encoder(nn.Module):
+     def __init__(self, d_model, ffn_hidden, num_heads, drop_prob, encoder_layer, max_sequence_length, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN):
+         super().__init__()
+         self.sentence_embedding = SentenceEmbedding(max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
+         self.layers = SequentialEncoder(*[EncoderLayer(d_model, ffn_hidden, num_heads, drop_prob) for _ in range(encoder_layer)])
+
+     def forward(self, x, self_attention_mask, start_token, end_token):
+         with autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu'):  # Mixed precision
+             x = self.sentence_embedding(x, start_token, end_token)
+             x = self.layers(x, self_attention_mask)
+         return x
+
+ # Decoder Layer with Gradient Checkpointing
+ class DecoderLayer(nn.Module):
+     def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
+         super().__init__()
+         self.self_attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
+         self.layer_norm1 = LayerNormalization(parameters_shape=[d_model])
+         self.dropout1 = nn.Dropout(p=drop_prob)
+         self.encoder_decoder_attention = MultiHeadCrossAttention(d_model=d_model, num_heads=num_heads)
+         self.layer_norm2 = LayerNormalization(parameters_shape=[d_model])
+         self.dropout2 = nn.Dropout(p=drop_prob)
+         self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
+         self.layer_norm3 = LayerNormalization(parameters_shape=[d_model])
+         self.dropout3 = nn.Dropout(p=drop_prob)
+
+     def forward(self, x, y, self_attention_mask, cross_attention_mask):
+         _y = y.clone()
+         y = checkpoint(self.self_attention, y, self_attention_mask, preserve_rng_state=True, use_reentrant=False)  # Gradient checkpointing
+         y = self.dropout1(y)
+         y = self.layer_norm1(y + _y)
+         _y = y.clone()
+         y = checkpoint(self.encoder_decoder_attention, x, y, cross_attention_mask, preserve_rng_state=True, use_reentrant=False)  # Gradient checkpointing
+         y = self.dropout2(y)
+         y = self.layer_norm2(y + _y)
+         _y = y.clone()
+         y = checkpoint(self.ffn, y, preserve_rng_state=True, use_reentrant=False)  # Gradient checkpointing
+         y = self.dropout3(y)
+         y = self.layer_norm3(y + _y)
+         return y
+
+ # Sequential Decoder
+ class SequentialDecoder(nn.Sequential):
+     def forward(self, *inputs):
+         x, y, self_attention_mask, cross_attention_mask = inputs
+         for module in self._modules.values():
+             y = module(x, y, self_attention_mask, cross_attention_mask)
+         return y
+
+ # Decoder with Mixed Precision
+ class Decoder(nn.Module):
+     def __init__(self, d_model, ffn_hidden, num_heads, drop_prob, decoder_layer, max_sequence_length, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN):
+         super().__init__()
+         self.sentence_embedding = SentenceEmbedding(max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
+         self.layers = SequentialDecoder(*[DecoderLayer(d_model, ffn_hidden, num_heads, drop_prob) for _ in range(decoder_layer)])
+
+     def forward(self, x, y, self_attention_mask, cross_attention_mask, start_token, end_token):
+         with autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu'):  # Mixed precision
+             y = self.sentence_embedding(y, start_token, end_token)
+             y = self.layers(x, y, self_attention_mask, cross_attention_mask)
+         return y
+
+ # Transformer with Mixed Precision and Gradient Checkpointing
+ class Transformer(nn.Module):
+     def __init__(self, d_model, ffn_hidden, num_heads, drop_prob, encoder_layer, decoder_layer, max_sequence_length, ne_vocab_size, english_to_index, nepali_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN):
+         super().__init__()
+         self.encoder = Encoder(d_model, ffn_hidden, num_heads, drop_prob, encoder_layer, max_sequence_length, english_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
+         self.decoder = Decoder(d_model, ffn_hidden, num_heads, drop_prob, decoder_layer, max_sequence_length, nepali_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
+         self.linear = nn.Linear(d_model, ne_vocab_size)
+         self.device = get_device()
+
+     def forward(self, x, y, encoder_self_attention_mask=None, decoder_self_attention_mask=None, decoder_cross_attention_mask=None, enc_start_token=False, enc_end_token=False, dec_start_token=False, dec_end_token=False):
+         with autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu'):  # Mixed precision
+             x = self.encoder(x, encoder_self_attention_mask, enc_start_token, enc_end_token)
+             out = self.decoder(x, y, decoder_self_attention_mask, decoder_cross_attention_mask, dec_start_token, dec_end_token)
+             out = self.linear(out)
+         return out
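
As a shape check, here is a minimal sketch of driving the Transformer directly with pre-tokenized tensors. The hyperparameters mirror inference.py, while the toy vocabulary and shared index maps are stand-in assumptions; real usage builds them from the SentencePiece tokenizers as inference.py does.

# Shape check for xtransformer.Transformer with a toy vocabulary.
import torch
from xtransformer import Transformer

vocab = {f"tok{i}": i for i in range(1000)}
vocab.update({'<START>': 1000, '<END>': 1001, '<PADDING>': 1002})

model = Transformer(
    d_model=512, ffn_hidden=2048, num_heads=8, drop_prob=0.1,
    encoder_layer=6, decoder_layer=4, max_sequence_length=100,
    ne_vocab_size=len(vocab), english_to_index=vocab, nepali_to_index=vocab,
    START_TOKEN='<START>', END_TOKEN='<END>', PADDING_TOKEN='<PADDING>',
)
model.eval()

src = torch.randint(0, 1000, (2, 100))  # pre-tokenized source batch
tgt = torch.randint(0, 1000, (2, 100))  # pre-tokenized target batch
with torch.no_grad():
    out = model(src, tgt)               # attention masks default to None
print(out.shape)                        # torch.Size([2, 100, 1003])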