yash committed
Commit 4d0d76c · 1 parent: 5851e94

first commit

Files changed (3)
  1. app.py +148 -0
  2. requirements.txt +93 -0
  3. transformer.py +305 -0
app.py ADDED
@@ -0,0 +1,148 @@
+ import torch
+ import numpy as np
+ import gradio as gr
+ from huggingface_hub import hf_hub_download
+ from transformer import Transformer
+
+
+ # Generated by filtering the appendix code
+ START_TOKEN = '<START>'
+ PADDING_TOKEN = '<PADDING>'
+ END_TOKEN = '<END>'
+
+
+ english_vocabulary = [START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/',
+                       '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
+                       ':', '<', '=', '>', '?', '@',
+                       '[', '\\', ']', '^', '_', '`',
+                       'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
+                       'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x',
+                       'y', 'z',
+                       '{', '|', '}', '~', PADDING_TOKEN, END_TOKEN]
+
+
+ gujarati_vocabulary = [
+     START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/',
+     '૦', '૧', '૨', '૩', '૪', '૫', '૬', '૭', '૮', '૯',
+     ':', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`',
+     'અ', 'આ', 'ઇ', 'ઈ', 'ઉ', 'ઊ', 'ઋ', 'એ', 'ઐ', 'ઓ', 'ઔ',
+     'ક', 'ખ', 'ગ', 'ઘ', 'ઙ', 'ચ', 'છ', 'જ', 'ઝ', 'ઞ',
+     'ટ', 'ઠ', 'ડ', 'ઢ', 'ણ', 'ત', 'થ', 'દ', 'ધ', 'ન',
+     'પ', 'ફ', 'બ', 'ભ', 'મ', 'ય', 'ર', 'લ', 'વ', 'શ',
+     'ષ', 'સ', 'હ', 'ળ', 'ક્ષ', 'જ્ઞ', 'ં', 'ઃ', 'ઁ', 'ા',
+     'િ', 'ી', 'ુ', 'ૂ', 'ે', 'ૈ', 'ો', 'ૌ', '્', 'ૐ',
+     '{', '|', '}', '~', PADDING_TOKEN, END_TOKEN
+ ]
+
+ index_to_gujarati = {k: v for k, v in enumerate(gujarati_vocabulary)}
+ gujarati_to_index = {v: k for k, v in enumerate(gujarati_vocabulary)}
+ index_to_english = {k: v for k, v in enumerate(english_vocabulary)}
+ english_to_index = {v: k for k, v in enumerate(english_vocabulary)}
+
+ # Hyperparameters (must match the values used to train the checkpoint)
+ d_model = 512
+ # batch_size = 64
+ ffn_hidden = 2048
+ num_heads = 8
+ drop_prob = 0.1
+ num_layers = 6
+ max_sequence_length = 200
+ kn_vocab_size = len(gujarati_vocabulary)  # target vocabulary size (parameter name follows transformer.py)
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+ transformer = Transformer(d_model,
+                           ffn_hidden,
+                           num_heads,
+                           drop_prob,
+                           num_layers,
+                           max_sequence_length,
+                           kn_vocab_size,
+                           english_to_index,
+                           gujarati_to_index,
+                           START_TOKEN,
+                           END_TOKEN,
+                           PADDING_TOKEN)
+
+ # Download the trained checkpoint from the Hub and restore the weights.
+ model_file = hf_hub_download(repo_id="yashAI007/English_to_Gujarati_Translation", filename="model.pth")
+ checkpoint = torch.load(model_file, map_location=device)  # map_location lets CPU-only hosts load a GPU-saved checkpoint
+ transformer.load_state_dict(checkpoint['model_state_dict'])
+ transformer.to(device)
+ transformer.eval()
+
+
+ NEG_INFTY = -1e9
+
+ def create_masks(eng_batch, kn_batch):
+     """Build encoder/decoder attention masks for a batch of raw sentences."""
+     num_sentences = len(eng_batch)
+     look_ahead_mask = torch.full([max_sequence_length, max_sequence_length], True)
+     look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)
+     encoder_padding_mask = torch.full([num_sentences, max_sequence_length, max_sequence_length], False)
+     decoder_padding_mask_self_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length], False)
+     decoder_padding_mask_cross_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length], False)
+
+     for idx in range(num_sentences):
+         eng_sentence_length, kn_sentence_length = len(eng_batch[idx]), len(kn_batch[idx])
+         eng_chars_to_padding_mask = np.arange(eng_sentence_length + 1, max_sequence_length)
+         kn_chars_to_padding_mask = np.arange(kn_sentence_length + 1, max_sequence_length)
+         encoder_padding_mask[idx, :, eng_chars_to_padding_mask] = True
+         encoder_padding_mask[idx, eng_chars_to_padding_mask, :] = True
+         decoder_padding_mask_self_attention[idx, :, kn_chars_to_padding_mask] = True
+         decoder_padding_mask_self_attention[idx, kn_chars_to_padding_mask, :] = True
+         decoder_padding_mask_cross_attention[idx, :, eng_chars_to_padding_mask] = True
+         decoder_padding_mask_cross_attention[idx, kn_chars_to_padding_mask, :] = True
+
+     encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFTY, 0)
+     decoder_self_attention_mask = torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFTY, 0)
+     decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFTY, 0)
+     return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask
+
+ def translate(eng_sentence):
+     """Greedily decode one Gujarati character per step until END_TOKEN or the maximum length."""
+     eng_sentence = (eng_sentence.lower(),)
+     kn_sentence = ("",)
+     with torch.no_grad():
+         for word_counter in range(max_sequence_length):
+             encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(eng_sentence, kn_sentence)
+             predictions = transformer(eng_sentence,
+                                       kn_sentence,
+                                       encoder_self_attention_mask.to(device),
+                                       decoder_self_attention_mask.to(device),
+                                       decoder_cross_attention_mask.to(device),
+                                       enc_start_token=False,
+                                       enc_end_token=False,
+                                       dec_start_token=True,
+                                       dec_end_token=False)
+             next_token_prob_distribution = predictions[0][word_counter]
+             next_token_index = torch.argmax(next_token_prob_distribution).item()
+             next_token = index_to_gujarati[next_token_index]
+             kn_sentence = (kn_sentence[0] + next_token,)
+             if next_token == END_TOKEN:
+                 break
+     translation = kn_sentence[0]
+     if translation.endswith(END_TOKEN):
+         translation = translation[:-len(END_TOKEN)]
+     return translation
+
+ examples = [
+     ["Hello, how are you?"],
+     ["What is your name?"],
+     ["I love programming."],
+     ["This is a beautiful day."],
+     ["Can you help me with this?"],
+     ["What time is it?"],
+     ["I am learning data science."],
+     ["Where is the nearest bus stop?"],
+     ["I enjoy reading books."],
+     ["Thank you for your help."]
+ ]
+
+ description = "This tool translates English sentences into Gujarati. Enter your text to get started!"
+
+ iface = gr.Interface(fn=translate,
+                      inputs="text",
+                      outputs="text",
+                      title="English to Gujarati Translation",
+                      examples=examples,
+                      description=description,
+                      )
+
+ if __name__ == "__main__":
+     iface.launch()
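
A quick way to sanity-check the decoding loop outside the Gradio UI is sketched below. It is not part of the commit; it assumes the pinned dependencies are installed and that the checkpoint download in app.py succeeds (importing app triggers the download and model construction, while iface.launch() only runs under __main__).

# Hypothetical smoke test for app.py (illustration only, not part of the commit).
from app import create_masks, translate

# Each returned mask has shape (1, 200, 200): one sentence, max_sequence_length = 200.
masks = create_masks(("hello",), ("",))
print([m.shape for m in masks])

# Greedy character-level decoding; prints the model's Gujarati output.
print(translate("How are you?"))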
requirements.txt ADDED
@@ -0,0 +1,93 @@
+ aiofiles==23.2.1
+ altair==5.3.0
+ annotated-types==0.7.0
+ anyio==4.4.0
+ attrs==23.2.0
+ certifi==2024.7.4
+ charset-normalizer==3.3.2
+ click==8.1.7
+ contourpy==1.2.1
+ cycler==0.12.1
+ dnspython==2.6.1
+ email_validator==2.2.0
+ exceptiongroup==1.2.2
+ fastapi==0.111.1
+ fastapi-cli==0.0.4
+ ffmpy==0.3.2
+ filelock==3.15.4
+ fonttools==4.53.1
+ fsspec==2024.6.1
+ gradio==4.38.1
+ gradio_client==1.1.0
+ h11==0.14.0
+ httpcore==1.0.5
+ httptools==0.6.1
+ httpx==0.27.0
+ huggingface-hub==0.23.5
+ idna==3.7
+ importlib_resources==6.4.0
+ Jinja2==3.1.4
+ jsonschema==4.23.0
+ jsonschema-specifications==2023.12.1
+ kiwisolver==1.4.5
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ matplotlib==3.9.1
+ mdurl==0.1.2
+ mpmath==1.3.0
+ networkx==3.3
+ numpy==1.26.4
+ nvidia-cublas-cu12==12.1.3.1
+ nvidia-cuda-cupti-cu12==12.1.105
+ nvidia-cuda-nvrtc-cu12==12.1.105
+ nvidia-cuda-runtime-cu12==12.1.105
+ nvidia-cudnn-cu12==8.9.2.26
+ nvidia-cufft-cu12==11.0.2.54
+ nvidia-curand-cu12==10.3.2.106
+ nvidia-cusolver-cu12==11.4.5.107
+ nvidia-cusparse-cu12==12.1.0.106
+ nvidia-nccl-cu12==2.20.5
+ nvidia-nvjitlink-cu12==12.5.82
+ nvidia-nvtx-cu12==12.1.105
+ orjson==3.10.6
+ packaging==24.1
+ pandas==2.2.2
+ pillow==10.4.0
+ pydantic==2.8.2
+ pydantic_core==2.20.1
+ pydub==0.25.1
+ Pygments==2.18.0
+ pyparsing==3.1.2
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.0.1
+ python-multipart==0.0.9
+ pytz==2024.1
+ PyYAML==6.0.1
+ referencing==0.35.1
+ regex==2024.5.15
+ requests==2.32.3
+ rich==13.7.1
+ rpds-py==0.19.0
+ ruff==0.5.2
+ safetensors==0.4.3
+ semantic-version==2.10.0
+ shellingham==1.5.4
+ six==1.16.0
+ sniffio==1.3.1
+ starlette==0.37.2
+ sympy==1.13.0
+ tokenizers==0.19.1
+ tomlkit==0.12.0
+ toolz==0.12.1
+ torch==2.3.1
+ tqdm==4.66.4
+ transformers==4.42.4
+ triton==2.3.1
+ typer==0.12.3
+ typing_extensions==4.12.2
+ tzdata==2024.1
+ urllib3==2.2.2
+ uvicorn==0.30.1
+ uvloop==0.19.0
+ watchfiles==0.22.0
+ websockets==11.0.3
transformer.py ADDED
@@ -0,0 +1,305 @@
+ import numpy as np
+ import torch
+ import math
+ from torch import nn
+ import torch.nn.functional as F
+
+ def get_device():
+     return torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+     # return torch.device('cpu')
+
+ def scaled_dot_product(q, k, v, mask=None):
+     # q, k, v: (batch, num_heads, seq_len, head_dim)
+     d_k = q.size()[-1]
+     scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
+     if mask is not None:
+         # permute so a (batch, seq_len, seq_len) mask broadcasts across the head dimension
+         scaled = scaled.permute(1, 0, 2, 3) + mask
+         scaled = scaled.permute(1, 0, 2, 3)
+     attention = F.softmax(scaled, dim=-1)
+     values = torch.matmul(attention, v)
+     return values, attention
+
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model, max_sequence_length):
+         super().__init__()
+         self.max_sequence_length = max_sequence_length
+         self.d_model = d_model
+
+     def forward(self):
+         # Standard sinusoidal encodings: sine on even indices, cosine on odd indices.
+         even_i = torch.arange(0, self.d_model, 2).float()
+         denominator = torch.pow(10000, even_i / self.d_model)
+         position = (torch.arange(self.max_sequence_length)
+                     .reshape(self.max_sequence_length, 1))
+         even_PE = torch.sin(position / denominator)
+         odd_PE = torch.cos(position / denominator)
+         stacked = torch.stack([even_PE, odd_PE], dim=2)
+         PE = torch.flatten(stacked, start_dim=1, end_dim=2)
+         return PE
+
+ class SentenceEmbedding(nn.Module):
+     "For a given sentence, create an embedding"
+     def __init__(self, max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN):
+         super().__init__()
+         self.vocab_size = len(language_to_index)
+         self.max_sequence_length = max_sequence_length
+         self.embedding = nn.Embedding(self.vocab_size, d_model)
+         self.language_to_index = language_to_index
+         self.position_encoder = PositionalEncoding(d_model, max_sequence_length)
+         self.dropout = nn.Dropout(p=0.1)
+         self.START_TOKEN = START_TOKEN
+         self.END_TOKEN = END_TOKEN
+         self.PADDING_TOKEN = PADDING_TOKEN
+
+     def batch_tokenize(self, batch, start_token, end_token):
+
+         def tokenize(sentence, start_token, end_token):
+             # Character-level tokenization, then optional start/end tokens and padding to max length.
+             sentence_word_indices = [self.language_to_index[token] for token in list(sentence)]
+             if start_token:
+                 sentence_word_indices.insert(0, self.language_to_index[self.START_TOKEN])
+             if end_token:
+                 sentence_word_indices.append(self.language_to_index[self.END_TOKEN])
+             for _ in range(len(sentence_word_indices), self.max_sequence_length):
+                 sentence_word_indices.append(self.language_to_index[self.PADDING_TOKEN])
+             return torch.tensor(sentence_word_indices)
+
+         tokenized = []
+         for sentence_num in range(len(batch)):
+             tokenized.append(tokenize(batch[sentence_num], start_token, end_token))
+         tokenized = torch.stack(tokenized)
+         return tokenized.to(get_device())
+
+     def forward(self, x, start_token, end_token):  # x is a batch of raw sentences
+         x = self.batch_tokenize(x, start_token, end_token)
+         x = self.embedding(x)
+         pos = self.position_encoder().to(get_device())
+         x = self.dropout(x + pos)
+         return x
+
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, d_model, num_heads):
+         super().__init__()
+         self.d_model = d_model
+         self.num_heads = num_heads
+         self.head_dim = d_model // num_heads
+         self.qkv_layer = nn.Linear(d_model, 3 * d_model)
+         self.linear_layer = nn.Linear(d_model, d_model)
+
+     def forward(self, x, mask):
+         batch_size, sequence_length, d_model = x.size()
+         qkv = self.qkv_layer(x)
+         qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
+         qkv = qkv.permute(0, 2, 1, 3)
+         q, k, v = qkv.chunk(3, dim=-1)
+         values, attention = scaled_dot_product(q, k, v, mask)
+         values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, self.num_heads * self.head_dim)
+         out = self.linear_layer(values)
+         return out
+
+
+ class LayerNormalization(nn.Module):
+     def __init__(self, parameters_shape, eps=1e-5):
+         super().__init__()
+         self.parameters_shape = parameters_shape
+         self.eps = eps
+         self.gamma = nn.Parameter(torch.ones(parameters_shape))
+         self.beta = nn.Parameter(torch.zeros(parameters_shape))
+
+     def forward(self, inputs):
+         dims = [-(i + 1) for i in range(len(self.parameters_shape))]
+         mean = inputs.mean(dim=dims, keepdim=True)
+         var = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
+         std = (var + self.eps).sqrt()
+         y = (inputs - mean) / std
+         out = self.gamma * y + self.beta
+         return out
+
+
+ class PositionwiseFeedForward(nn.Module):
+     def __init__(self, d_model, hidden, drop_prob=0.1):
+         super(PositionwiseFeedForward, self).__init__()
+         self.linear1 = nn.Linear(d_model, hidden)
+         self.linear2 = nn.Linear(hidden, d_model)
+         self.relu = nn.ReLU()
+         self.dropout = nn.Dropout(p=drop_prob)
+
+     def forward(self, x):
+         x = self.linear1(x)
+         x = self.relu(x)
+         x = self.dropout(x)
+         x = self.linear2(x)
+         return x
+
+
+ class EncoderLayer(nn.Module):
+     def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
+         super(EncoderLayer, self).__init__()
+         self.attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
+         self.norm1 = LayerNormalization(parameters_shape=[d_model])
+         self.dropout1 = nn.Dropout(p=drop_prob)
+         self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
+         self.norm2 = LayerNormalization(parameters_shape=[d_model])
+         self.dropout2 = nn.Dropout(p=drop_prob)
+
+     def forward(self, x, self_attention_mask):
+         residual_x = x.clone()
+         x = self.attention(x, mask=self_attention_mask)
+         x = self.dropout1(x)
+         x = self.norm1(x + residual_x)
+         residual_x = x.clone()
+         x = self.ffn(x)
+         x = self.dropout2(x)
+         x = self.norm2(x + residual_x)
+         return x
+
+ class SequentialEncoder(nn.Sequential):
+     def forward(self, *inputs):
+         x, self_attention_mask = inputs
+         for module in self._modules.values():
+             x = module(x, self_attention_mask)
+         return x
+
+ class Encoder(nn.Module):
+     def __init__(self,
+                  d_model,
+                  ffn_hidden,
+                  num_heads,
+                  drop_prob,
+                  num_layers,
+                  max_sequence_length,
+                  language_to_index,
+                  START_TOKEN,
+                  END_TOKEN,
+                  PADDING_TOKEN):
+         super().__init__()
+         self.sentence_embedding = SentenceEmbedding(max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
+         self.layers = SequentialEncoder(*[EncoderLayer(d_model, ffn_hidden, num_heads, drop_prob)
+                                           for _ in range(num_layers)])
+
+     def forward(self, x, self_attention_mask, start_token, end_token):
+         x = self.sentence_embedding(x, start_token, end_token)
+         x = self.layers(x, self_attention_mask)
+         return x
+
+
+ class MultiHeadCrossAttention(nn.Module):
+     def __init__(self, d_model, num_heads):
+         super().__init__()
+         self.d_model = d_model
+         self.num_heads = num_heads
+         self.head_dim = d_model // num_heads
+         self.kv_layer = nn.Linear(d_model, 2 * d_model)
+         self.q_layer = nn.Linear(d_model, d_model)
+         self.linear_layer = nn.Linear(d_model, d_model)
+
+     def forward(self, x, y, mask):
+         # x: encoder output (keys/values); y: decoder states (queries)
+         batch_size, sequence_length, d_model = x.size()  # in practice this is the same for both languages, so it could technically be combined with normal attention
+         kv = self.kv_layer(x)
+         q = self.q_layer(y)
+         kv = kv.reshape(batch_size, sequence_length, self.num_heads, 2 * self.head_dim)
+         q = q.reshape(batch_size, sequence_length, self.num_heads, self.head_dim)
+         kv = kv.permute(0, 2, 1, 3)
+         q = q.permute(0, 2, 1, 3)
+         k, v = kv.chunk(2, dim=-1)
+         values, attention = scaled_dot_product(q, k, v, mask)  # the cross-attention mask (if provided) hides padding positions
+         values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, d_model)
+         out = self.linear_layer(values)
+         return out
+
+
+ class DecoderLayer(nn.Module):
+     def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
+         super(DecoderLayer, self).__init__()
+         self.self_attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
+         self.layer_norm1 = LayerNormalization(parameters_shape=[d_model])
+         self.dropout1 = nn.Dropout(p=drop_prob)
+
+         self.encoder_decoder_attention = MultiHeadCrossAttention(d_model=d_model, num_heads=num_heads)
+         self.layer_norm2 = LayerNormalization(parameters_shape=[d_model])
+         self.dropout2 = nn.Dropout(p=drop_prob)
+
+         self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
+         self.layer_norm3 = LayerNormalization(parameters_shape=[d_model])
+         self.dropout3 = nn.Dropout(p=drop_prob)
+
+     def forward(self, x, y, self_attention_mask, cross_attention_mask):
+         _y = y.clone()
+         y = self.self_attention(y, mask=self_attention_mask)
+         y = self.dropout1(y)
+         y = self.layer_norm1(y + _y)
+
+         _y = y.clone()
+         y = self.encoder_decoder_attention(x, y, mask=cross_attention_mask)
+         y = self.dropout2(y)
+         y = self.layer_norm2(y + _y)
+
+         _y = y.clone()
+         y = self.ffn(y)
+         y = self.dropout3(y)
+         y = self.layer_norm3(y + _y)
+         return y
+
+
+ class SequentialDecoder(nn.Sequential):
+     def forward(self, *inputs):
+         x, y, self_attention_mask, cross_attention_mask = inputs
+         for module in self._modules.values():
+             y = module(x, y, self_attention_mask, cross_attention_mask)
+         return y
+
+ class Decoder(nn.Module):
+     def __init__(self,
+                  d_model,
+                  ffn_hidden,
+                  num_heads,
+                  drop_prob,
+                  num_layers,
+                  max_sequence_length,
+                  language_to_index,
+                  START_TOKEN,
+                  END_TOKEN,
+                  PADDING_TOKEN):
+         super().__init__()
+         self.sentence_embedding = SentenceEmbedding(max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
+         self.layers = SequentialDecoder(*[DecoderLayer(d_model, ffn_hidden, num_heads, drop_prob) for _ in range(num_layers)])
+
+     def forward(self, x, y, self_attention_mask, cross_attention_mask, start_token, end_token):
+         y = self.sentence_embedding(y, start_token, end_token)
+         y = self.layers(x, y, self_attention_mask, cross_attention_mask)
+         return y
+
+
+ class Transformer(nn.Module):
+     def __init__(self,
+                  d_model,
+                  ffn_hidden,
+                  num_heads,
+                  drop_prob,
+                  num_layers,
+                  max_sequence_length,
+                  kn_vocab_size,
+                  english_to_index,
+                  kannada_to_index,
+                  START_TOKEN,
+                  END_TOKEN,
+                  PADDING_TOKEN
+                  ):
+         # The kn_/kannada_ parameter names are kept as-is; app.py passes the Gujarati vocabulary for them.
+         super().__init__()
+         self.encoder = Encoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers, max_sequence_length, english_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
+         self.decoder = Decoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers, max_sequence_length, kannada_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
+         self.linear = nn.Linear(d_model, kn_vocab_size)
+         self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+     def forward(self,
+                 x,
+                 y,
+                 encoder_self_attention_mask=None,
+                 decoder_self_attention_mask=None,
+                 decoder_cross_attention_mask=None,
+                 enc_start_token=False,
+                 enc_end_token=False,
+                 dec_start_token=False,
+                 dec_end_token=False):
+         x = self.encoder(x, encoder_self_attention_mask, start_token=enc_start_token, end_token=enc_end_token)
+         out = self.decoder(x, y, decoder_self_attention_mask, decoder_cross_attention_mask, start_token=dec_start_token, end_token=dec_end_token)
+         out = self.linear(out)
+         return out
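
For reference, a minimal forward-pass sketch for the Transformer class defined above. It is not part of the commit, and the tiny vocabularies and hyperparameters are illustrative only; with raw strings as input and no masks, the forward pass should return logits of shape (batch, max_sequence_length, target_vocab_size).

# Minimal shape check for transformer.py (illustration only, not part of the commit).
import torch
from transformer import Transformer

START, END, PAD = '<START>', '<END>', '<PADDING>'
src_vocab = [START, ' ', 'a', 'b', 'c', PAD, END]
tgt_vocab = [START, ' ', 'x', 'y', 'z', PAD, END]
src_to_index = {ch: i for i, ch in enumerate(src_vocab)}
tgt_to_index = {ch: i for i, ch in enumerate(tgt_vocab)}

model = Transformer(d_model=64, ffn_hidden=128, num_heads=4, drop_prob=0.1,
                    num_layers=2, max_sequence_length=20,
                    kn_vocab_size=len(tgt_vocab),
                    english_to_index=src_to_index,
                    kannada_to_index=tgt_to_index,
                    START_TOKEN=START, END_TOKEN=END, PADDING_TOKEN=PAD)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
model.eval()

with torch.no_grad():
    # Masks default to None; the decoder input gets a start token prepended.
    out = model(("ab c",), ("xy",), dec_start_token=True)
print(out.shape)  # expected: torch.Size([1, 20, 7])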