File size: 176 Bytes
426ffb5
 
 
 
 
 
1
2
3
4
5
6
7
import torch.nn as nn


class TokenEmbedding(nn.Embedding):
    """Embedding lookup table for token ids, with a reserved padding row.

    A thin ``nn.Embedding`` subclass that fixes the constructor defaults
    used for token embeddings: a 512-wide vector per token, with index 0
    treated as the padding token unless the caller says otherwise.
    """

    def __init__(self, vocab_size: int, embed_size: int = 512, *,
                 padding_idx: "int | None" = 0) -> None:
        """Build the embedding table.

        Args:
            vocab_size: Number of rows in the table (``num_embeddings``).
            embed_size: Width of each embedding vector (``embedding_dim``).
            padding_idx: Row passed to ``nn.Embedding`` as the padding
                index. Defaults to 0, matching the previous hard-coded
                value; pass ``None`` to disable padding handling.
        """
        # Keyword-only so existing positional callers (vocab_size,
        # embed_size) cannot accidentally bind a third argument to it.
        super().__init__(vocab_size, embed_size, padding_idx=padding_idx)