class ProtoNet(nn.Module):
    """Prototype-based classifier: conv embedding + distance to stored prototypes.

    Inputs are embedded by three 3x3 convolutions (stride 1, padding 1, so the
    spatial size is preserved), each followed by ReLU, then flattened into
    vectors of size ``in_channels * H * W``. Those vectors are compared with
    the prototypes set via :meth:`prototyping`, and the output is a
    log-softmax over the negated distances.
    """

    def __init__(self, in_channels, hidden_channel):
        """Build the embedding network.

        Args:
            in_channels: Number of input channels; the last conv layer also
                outputs this many channels.
            hidden_channel: Number of channels of the intermediate conv layers.
        """
        super().__init__()
        # Embedding network.
        self.conv1 = nn.Conv2d(in_channels, hidden_channel, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(hidden_channel, hidden_channel, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(hidden_channel, in_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        # Despite the attribute name this is a *log*-softmax, so forward()
        # returns log-probabilities rather than probabilities.
        self.softmax = nn.LogSoftmax(dim=1)
        # Filled in by prototyping(). Initialised here so that cdist() can
        # report a missing prototype set with its own message instead of an
        # AttributeError raised by nn.Module.__getattr__.
        self.prototypes = None
    # __init__()

    def prototyping(self, prototypes):
        """Store the prototype tensor used by :meth:`cdist`.

        Args:
            prototypes: Tensor whose second dimension must equal the flattened
                embedding size; presumably ``(K, D)`` with one row per class
                (TODO confirm against callers).
        """
        self.prototypes = prototypes

    def cdist(self, x: torch.Tensor, metric="euclidean") -> torch.Tensor:
        """Return the distance from each row of ``x`` to each prototype.

        Args:
            x: Embeddings, ``(N, D)`` as produced inside :meth:`forward`;
                ``D`` must equal ``self.prototypes.size(1)``.
            metric: ``"euclidean"`` (L2 distance via ``torch.cdist``) or
                ``"cosine"`` (``1 - cosine similarity``).

        Returns:
            Distance tensor of shape ``(N, K)``, ``K`` being the number of
            prototypes.

        Raises:
            AssertionError: If prototypes have not been set, or the feature
                dimensions differ. NOTE: these checks are ``assert``
                statements and are skipped under ``python -O``.
            ValueError: If ``metric`` is not supported.
        """
        assert self.prototypes is not None, "Prototypes must be set before calling cdist."
        assert x.size(1) == self.prototypes.size(1), "Feature dimensions must match."

        if metric == "euclidean":
            return torch.cdist(x, self.prototypes, p=2)  # L2 distance
        if metric == "cosine":
            # Broadcast (N, 1, D) against (1, K, D) and reduce over D -> (N, K).
            similarity = F.cosine_similarity(x.unsqueeze(1), self.prototypes.unsqueeze(0), dim=2)
            return 1 - similarity
        raise ValueError("Unsupported distance metric. Choose 'euclidean' or 'cosine'.")
    # cdist()

    def forward(self, x, metric="euclidean"):
        """Score ``x`` against the stored prototypes.

        Args:
            x: Input batch for ``nn.Conv2d``, presumably
                ``(B, in_channels, H, W)`` (TODO confirm).
            metric: Distance metric forwarded to :meth:`cdist`.

        Returns:
            Log-probabilities of shape ``(B, K)``. These are the log-softmax of
            the negated distances, so nearer prototypes score higher.
        """
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))
        x = self.flatten(x)
        dists = self.cdist(x, metric=metric)
        # Negate so that a smaller distance yields a larger log-probability.
        return self.softmax(-dists)
    # forward
# ProtoNet