lif31up's picture
init: safetensor, bin
9d58bb5
class ProtoNet(nn.Module):
    """Prototypical-network classifier over a small convolutional embedding.

    Inputs are embedded by three 3x3 convolutions (stride 1, padding 1, so the
    spatial size is preserved), each followed by ReLU, and then flattened. The
    flattened embedding is compared against a set of class prototypes, and the
    negated distances are turned into log-probabilities with a log-softmax.

    Prototypes are not learned by this module. They must be supplied with
    :meth:`prototyping` before :meth:`cdist` or :meth:`forward` is called.

    Args:
        in_channels: Number of channels of the input images. The last
            convolution maps back to this many channels, so for an input of
            shape ``(N, in_channels, H, W)`` the flattened embedding has
            ``in_channels * H * W`` features.
        hidden_channel: Number of channels of the two hidden convolutions.
    """

    def __init__(self, in_channels, hidden_channel):
        super().__init__()
        # Embedding network.
        self.conv1 = nn.Conv2d(in_channels, hidden_channel, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(hidden_channel, hidden_channel, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(hidden_channel, in_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.softmax = nn.LogSoftmax(dim=1)
        # Class prototypes, set via prototyping(). Initialised to None so that
        # cdist() can report a clear error instead of an AttributeError.
        # Stored as a plain attribute (not a parameter or buffer).
        self.prototypes = None
    # __init__()

    def prototyping(self, prototypes):
        """Set the class prototypes used by :meth:`cdist` and :meth:`forward`.

        Args:
            prototypes: Tensor expected to be of shape ``(K, D)``, one row per
                class, where ``D`` matches the flattened embedding size.
        """
        self.prototypes = prototypes
    # prototyping()

    def cdist(self, x: torch.Tensor, metric="euclidean") -> torch.Tensor:
        """Distance function: distance from each row of ``x`` to each prototype.

        Args:
            x: Embeddings of shape ``(N, D)``.
            metric: ``"euclidean"`` for the (non-squared) L2 distance, or
                ``"cosine"`` for ``1 - cosine similarity``.

        Returns:
            Tensor of shape ``(N, K)``, where ``K`` is the number of prototypes.

        Raises:
            RuntimeError: If :meth:`prototyping` has not been called.
            ValueError: If the feature dimensions of ``x`` and the prototypes
                differ, or if ``metric`` is not supported.
        """
        # getattr with a default also covers instances unpickled from a
        # version of this class that did not initialise the attribute.
        prototypes = getattr(self, "prototypes", None)
        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if prototypes is None:
            raise RuntimeError("Prototypes must be set before calling cdist.")
        if x.size(1) != prototypes.size(1):
            raise ValueError("Feature dimensions must match.")
        if metric == "euclidean":
            return torch.cdist(x, prototypes, p=2)  # L2 distance
        if metric == "cosine":
            # Broadcast (N, 1, D) against (1, K, D) to get an (N, K) matrix.
            return 1 - F.cosine_similarity(x.unsqueeze(1), prototypes.unsqueeze(0), dim=2)
        raise ValueError("Unsupported distance metric. Choose 'euclidean' or 'cosine'.")
    # cdist()

    def forward(self, x, metric="euclidean"):
        """Return log-probabilities of each input belonging to each prototype.

        Args:
            x: Image batch accepted by ``conv1``; presumably of shape
                ``(N, in_channels, H, W)``.
            metric: Distance metric forwarded to :meth:`cdist`.

        Returns:
            Tensor of shape ``(N, K)``: log-softmax over prototypes of the
            negated distances, so closer prototypes score higher.
        """
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.flatten(x)
        # Negate distances so the nearest prototype gets the largest logit.
        return self.softmax(-self.cdist(x, metric=metric))
    # forward()
# ProtoNet