File size: 157 Bytes
426ffb5
 
 
 
 
 
1
2
3
4
5
6
7
import torch.nn as nn


class SegmentEmbedding(nn.Embedding):
    """Embedding table mapping integer segment ids to dense vectors.

    A thin subclass of ``nn.Embedding``. Its defaults build a 3-row table
    with row 0 reserved as the padding index, matching the original
    hard-coded configuration. Which segment each non-padding id stands for
    is decided by whoever builds the id tensor. NOTE(review): presumably
    ids 1 and 2 mark the first and second segment; confirm against that
    caller.

    Args:
        embed_size: Width of each embedding vector.
        num_segments: Number of rows in the table, padding row included.
            Defaults to 3, the previously hard-coded value.
        padding_idx: Row that ``nn.Embedding`` initialises to zeros and
            excludes from gradient updates; ``None`` disables padding.
            Defaults to 0, the previously hard-coded value.
    """

    def __init__(self, embed_size=512, *, num_segments=3, padding_idx=0):
        # New options are keyword-only so the single positional argument
        # (embed_size) keeps its original meaning for existing callers.
        super().__init__(num_segments, embed_size, padding_idx=padding_idx)