|
class ProtoNet(nn.Module):
    """Small convolutional prototypical network.

    Three 3x3 convolutions (stride 1, padding 1, so spatial size is preserved),
    each followed by ReLU, map an input with ``in_channels`` channels back to
    ``in_channels`` channels. The result is flattened and scored by its
    distance to a set of prototype vectors supplied via :meth:`prototyping`.

    Args:
        in_channels: Channel count of the input and of the last convolution's
            output.
        hidden_channel: Channel count of the two hidden convolutions.
    """

    def __init__(self, in_channels, hidden_channel):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, hidden_channel, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(hidden_channel, hidden_channel, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(hidden_channel, in_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.softmax = nn.LogSoftmax(dim=1)

        # Plain attribute (not a registered buffer), so the state_dict layout is
        # unchanged and .to()/.cuda() do not move it. Initialised here so that
        # cdist() reports a clear error instead of an AttributeError when
        # prototyping() was never called.
        self.prototypes = None

    def prototyping(self, prototypes):
        """Store the prototype vectors used by :meth:`cdist` and :meth:`forward`.

        Args:
            prototypes: Tensor of shape ``(num_prototypes, feature_dim)``.
                ``feature_dim`` must equal the flattened feature size produced
                in :meth:`forward` (presumably ``in_channels * H * W`` for
                ``(N, in_channels, H, W)`` inputs; confirm against callers).
        """
        self.prototypes = prototypes

    def cdist(self, x: torch.Tensor, metric="euclidean") -> torch.Tensor:
        """Return the distance from every row of ``x`` to every prototype.

        Args:
            x: Feature tensor of shape ``(N, feature_dim)``.
            metric: ``"euclidean"`` for the L2 distance, or ``"cosine"`` for
                ``1 - cosine_similarity``.

        Returns:
            Tensor of shape ``(N, num_prototypes)``.

        Raises:
            RuntimeError: If :meth:`prototyping` has not been called.
            ValueError: If the feature dimensions of ``x`` and the prototypes
                differ, or ``metric`` is not supported.
        """
        # Explicit raises rather than ``assert`` so the checks survive ``python -O``.
        if self.prototypes is None:
            raise RuntimeError("Prototypes must be set before calling cdist.")
        if x.size(-1) != self.prototypes.size(-1):
            raise ValueError("Feature dimensions must match.")

        if metric == "euclidean":
            return torch.cdist(x, self.prototypes, p=2)
        if metric == "cosine":
            # Broadcast (N, 1, D) against (1, K, D) -> (N, K) similarities.
            similarity = F.cosine_similarity(
                x.unsqueeze(1), self.prototypes.unsqueeze(0), dim=2
            )
            return 1 - similarity
        raise ValueError("Unsupported distance metric. Choose 'euclidean' or 'cosine'.")

    def forward(self, x, metric="euclidean"):
        """Score ``x`` against the stored prototypes.

        Args:
            x: Input batch, presumably ``(N, in_channels, H, W)``; confirm
                against callers.
            metric: Distance metric forwarded to :meth:`cdist`.

        Returns:
            Log-probabilities of shape ``(N, num_prototypes)``. This is a
            log-softmax over the negated distances, so nearer prototypes
            score higher.
        """
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))
        x = self.flatten(x)
        dists = self.cdist(x, metric=metric)
        return self.softmax(-dists)
|
|
|
|