import torch.nn as nn


class SegmentEmbedding(nn.Embedding):
    """Learned lookup table mapping integer segment ids to vectors.

    This is a thin ``nn.Embedding`` subclass with ``padding_idx=0``. Row 0
    of the weight matrix is initialised to zeros. It receives no gradient
    during training, so id 0 behaves as a fixed "no segment" vector.

    NOTE(review): ids 1..num_segments-1 presumably label the individual
    segments (e.g. first/second sentence). Confirm this against the code
    that produces the segment ids.

    Args:
        embed_size: Dimensionality of each embedding vector.
        num_segments: Number of rows in the table, padding row included.
            Defaults to 3 (padding + two segments), which matches the
            previous hard-coded size.
    """

    def __init__(self, embed_size: int = 512, num_segments: int = 3) -> None:
        # padding_idx=0 reserves row 0 as the zero, non-trainable padding vector.
        super().__init__(num_segments, embed_size, padding_idx=0)