init: safetensor, bin
- config.json                        +8  -0
- config.py                          +0  -8
- src/model/ProtoNet.py → model.py   +7  -9
- prototypical_network.bin           +3  -0
- prototypical_network.safetensors   +3  -0
- run.py                             +0 -41
- src/FewShotEpisoder.py             +0 -84
- src/evaluate.py                    +0 -48
- src/train.py                       +0 -73
config.json
ADDED
@@ -0,0 +1,8 @@
+{
+  "state": "model.state_dict()",
+  "FRAMEWORK": [5, 5, 2],
+  "MODEL_CONFIG": [3, 26, 3],
+  "HYPER_PARAMETERS": {"lr": 0.0001, "weight_decay": 0.0001},
+  "TRANSFORM": "transform",
+  "METRIC": "euclidean"
+}
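For reference, a minimal sketch of how this config might be read back. The key names come from the file above; the loader itself is hypothetical and not part of this commit. Note that "state" and "TRANSFORM" hold placeholder strings rather than serialized objects; the actual weights live in the .bin/.safetensors files added below.

import json

with open("config.json") as f:
  config = json.load(f)

n_way, k_shot, n_query = config["FRAMEWORK"]                         # episode shape: 5-way, 5-shot, 2 queries
in_channels, hidden_channels, out_channels = config["MODEL_CONFIG"]  # [3, 26, 3]
optim_kwargs = config["HYPER_PARAMETERS"]                            # {"lr": 0.0001, "weight_decay": 0.0001}
metric = config["METRIC"]                                            # "euclidean" or "cosine"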
config.py
DELETED
@@ -1,8 +0,0 @@
-HYPERPARAMETER_CONFIG = {
-  "lr": 0.0001,
-  "weight_decay": 0.0001
-} # HYPERPARAMETER_CONFIG
-TRAINING_CONFIG = {
-  "iters": 10,
-  "epochs": 30,
-} # TRAINING_CONFIG
src/model/ProtoNet.py → model.py
RENAMED
@@ -1,13 +1,10 @@
-import torch
-from torch import nn
-import torch.nn.functional as torch_f
-
 class ProtoNet(nn.Module):
-  def __init__(self, in_channels
+  def __init__(self, in_channels, hidden_channel):
     super(ProtoNet, self).__init__()
+    # embedding network
     self.conv1 = nn.Conv2d(in_channels, hidden_channel, kernel_size=3, stride=1, padding=1)
     self.conv2 = nn.Conv2d(hidden_channel, hidden_channel, kernel_size=3, stride=1, padding=1)
-    self.conv3 = nn.Conv2d(hidden_channel,
+    self.conv3 = nn.Conv2d(hidden_channel, in_channels, kernel_size=3, stride=1, padding=1)
     self.relu = nn.ReLU()
     self.flatten = nn.Flatten()
     self.softmax = nn.LogSoftmax(dim=1)
@@ -16,18 +13,19 @@ class ProtoNet(nn.Module):
   def prototyping(self, prototypes): self.prototypes = prototypes

   def cdist(self, x: torch.Tensor, metric="euclidean") -> torch.Tensor:
+    # distance function
     assert self.prototypes is not None, "Prototypes must be set before calling cdist."
     assert x.size(1) == self.prototypes.size(1), "Feature dimensions must match."
     if metric == "euclidean":
       dists = torch.cdist(x, self.prototypes, p=2) # L2 distance
     elif metric == "cosine":
-      dists = 1 -
+      dists = 1 - F.cosine_similarity(x.unsqueeze(1), self.prototypes.unsqueeze(0), dim=2) # 1 - cosine similarity
     else:
       raise ValueError("Unsupported distance metric. Choose 'euclidean' or 'cosine'.")
     return dists
   # cdist()

-  def forward(self, x):
+  def forward(self, x, metric="euclidean"):
     x = self.conv1(x)
     x = self.relu(x)
     x = self.conv2(x)
@@ -35,7 +33,7 @@ class ProtoNet(nn.Module):
     x = self.conv3(x)
     x = self.relu(x)
     x = self.flatten(x)
-    x = self.cdist(x, metric=
+    x = self.cdist(x, metric=metric)
     return self.softmax(-x)
   # forward
 # ProtoNet
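Note that the rename drops the module's import lines while the new body still references torch, nn, and F (in the cosine branch), so model.py will not import as committed. Below is a hedged usage sketch, assuming those three imports are restored and that inputs are 28×28 RGB images (an assumption, not fixed by this commit); it builds prototypes from flattened support features the same way the deleted train.py below does.

import torch
from model import ProtoNet  # assumes model.py regains: import torch / from torch import nn / import torch.nn.functional as F

model = ProtoNet(in_channels=3, hidden_channel=26).eval()
support = torch.randn(5, 3, 28, 28)  # one example per class: 5-way, 1-shot
query = torch.randn(4, 3, 28, 28)

with torch.no_grad():
  # the 3x3/stride-1/pad-1 convs preserve H and W, and conv3 maps back to
  # in_channels, so the flattened embedding dimension is 3*28*28 = 2352,
  # which the prototypes must match (cdist asserts this)
  model.prototyping(support.flatten(start_dim=1))
  log_probs = model(query, metric="euclidean")  # (4, 5) log-probabilities over the 5 prototypes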
prototypical_network.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1207dad08600ea1a3f1a622d2469a3208ebc8dd007c3b6051b256d7df5103f03
+size 34442
prototypical_network.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fe7d84d01a2cbd6369d78422355940f7f902eb78e9359f764d8ebc7c77eacad8
+size 30620
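Both files are Git LFS pointers for the same initial checkpoint in two formats. A sketch of loading each; this assumes the safetensors file holds a flat tensor dict, while the .bin side is read with torch.load (the exact contents are not documented in this commit):

import torch
from safetensors.torch import load_file

state_bin = torch.load("prototypical_network.bin", map_location="cpu")  # pickled checkpoint
state_st = load_file("prototypical_network.safetensors")                # tensors only, no pickled objects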
run.py
DELETED
@@ -1,41 +0,0 @@
-import argparse
-import torchvision as tv
-from src.train import train
-from src.evaluate import evaluate
-
-def main():
-  # eval(default)
-  parser = argparse.ArgumentParser(description="Few-shot learning using Prototypical Network")
-  parser.add_argument("--model", type=str, help="path of your model")
-  parser.add_argument("--dataset", type=str, help="path of your dataset")
-
-  # train
-  subparser = parser.add_subparsers(title="subcommands", dest="subcommand")
-  parser_train = subparser.add_parser("train", help="train your model")
-  parser_train.add_argument("--dataset", type=str, help="path to your dataset")
-  parser_train.add_argument("--save_to", type=str, help="path to save your model")
-  parser_train.add_argument("--n_way", type=int, help="number of classes per episode")
-  parser_train.add_argument("--k_shot", type=int, help="number of support samples per class")
-  parser_train.add_argument("--n_query", type=int, help="number of query samples per class")
-  parser_train.add_argument("--iters", type=int, help="how much iteration your model does for an episode")
-  parser_train.add_argument("--epochs", type=int, help="how much epochs your model does for training")
-  parser_train.set_defaults(func=lambda kwargs: train(
-    DATASET=kwargs.dataset,
-    SAVE_TO=kwargs.save_to,
-    N_WAY=kwargs.n_way,
-    K_SHOT=kwargs.k_shot,
-    N_QUERY=kwargs.n_query)
-  ) # parser_train.set_defaults()
-
-  # download dataset
-  parser_download = subparser.add_parser("download", help="download dataset")
-  parser_download.add_argument("--path", type=str, help="path to download dataset")
-  parser_download.set_defaults(func=lambda kwargs: tv.datasets.Omniglot(root=kwargs.path, background=True, download=True))
-
-  # parse logic
-  args = parser.parse_args()
-  if hasattr(args, 'func'): args.func(args)
-  else: evaluate(MODEL=args.model, DATASET=args.dataset)
-# main():
-
-if __name__ == "__main__": main()
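Reconstructed from the argparse definitions above, the deleted CLI was presumably invoked like this (paths are illustrative, not documented in the repo):

python run.py download --path ./data
python run.py train --dataset ./data/omniglot-py/images_background/Futurama --save_to ./model/model.pth --n_way 5 --k_shot 5 --n_query 2
python run.py --model ./model/model.pth --dataset ./data/omniglot-py/images_background/Futurama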
src/FewShotEpisoder.py
DELETED
@@ -1,84 +0,0 @@
-import random
-import typing
-import torch
-from torch.utils.data import Dataset
-import torch.nn.functional as F
-
-class FewShotDataset(Dataset):
-  """ A custom Dataset class for Few-Shot Learning tasks.
-  This dataset can operate in two modes: "support" (for prototype calculation) and "query" (for evaluation). """
-  def __init__(self, dataset, indices: list, classes: list, transform: typing.Callable, mode="support"):
-    """ Args:
-      dataset (list): List of (feature, label) pairs.
-      indices (list): List of indices to be used for the dataset.
-      transform (callable): Transform to be applied to the features.
-      mode (str): Mode of operation, either "support" or "query". Default is "support". """
-    assert mode in ["support", "query"], "Invalid mode. Must be either 'support' or 'query'." # check if mode is valid
-    assert dataset and indices and classes is not None, "Dataset or indices cannot be None." # check if dataset is not None
-
-    self.dataset, self.indices, self.classes = dataset, indices, classes
-    self.mode, self.transform = mode, transform
-  # __init__():
-
-  def __getitem__(self, index: int):
-    """ Returns a sample from the dataset at the given index.
-    Args: index of the sample to be retrieved.
-    Returns: tuple of the transformed feature and the label. """
-    if index >= len(self.indices):
-      raise IndexError("Index out of bounds") # check if index is out of bounds
-    feature, label = self.dataset[self.indices[index]]
-    # apply transformation
-    feature = self.transform(feature)
-    if self.mode == "query": # if mode is query, convert label to one-hot vector
-      label = F.one_hot(torch.tensor(self.classes.index(label)), num_classes=len(self.classes)).float()
-    return feature, label
-  # __getitem__():
-
-  def __len__(self): return len(self.indices)
-# FewShotDataset()
-
-class FewShotEpisoder:
-  """ A class to generate episodes for Few-Shot Learning.
-  Each episode consists of a support set and a query set. """
-  def __init__(self, dataset, classes: list, k_shot: int, n_query: int, transform: typing.Callable):
-    """ Args:
-      dataset (Dataset): The base dataset to generate episodes from.
-      k_shot (int): Number of support samples per class.
-      n_query (int): Number of query samples per class.
-      transform (callable): Transform to be applied to the features. """
-    assert k_shot > 0 and n_query > 0, "k_shot and n_query must be greater than 0." # check if k_shot and n_query are valid
-
-    self.k_shot, self.n_query, self.classes = k_shot, n_query, classes
-    self.dataset, self.transform = dataset, transform
-    self.indices_c = self.get_class_indices()
-  # __init__()
-
-  def get_class_indices(self) -> dict:
-    """ Initialize the class indices for the dataset.
-    Returns: dict of indices grouped by class. """
-    indices_c = {label: [] for label in range(self.classes.__len__())}
-    for index, (_, label) in enumerate(self.dataset):
-      if label in self.classes: indices_c[self.classes.index(label)].append(index)
-    for label, _indices_c in indices_c.items():
-      indices_c[label] = random.sample(_indices_c, self.k_shot + self.n_query)
-    return indices_c
-  # get_class_indices():
-
-  def get_episode(self) -> tuple: # select classes using list of chosen indexes
-    """ Generate an episode consisting of a support set and a query set.
-    Returns: tuple of a FewShotDataset for the support set and a FewShotDataset for the query set. """
-    # get support and query examples
-    support_examples, query_examples = [], []
-    for class_label in range(self.classes.__len__()):
-      if len(self.indices_c[class_label]) < self.k_shot + self.n_query: continue # skip class if it doesn't have enough samples
-      selected_indices = random.sample(self.indices_c[class_label], self.k_shot + self.n_query)
-      support_examples.extend(selected_indices[:self.k_shot])
-      query_examples.extend(selected_indices)
-
-    # init support and query datasets
-    support_set = FewShotDataset(self.dataset, support_examples, self.classes, self.transform, "support")
-    query_set = FewShotDataset(self.dataset, query_examples, self.classes, self.transform, "query")
-
-    return support_set, query_set
-  # get_episode()
-# FewShotEpisoder()
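A hedged usage sketch of the deleted episoder; the dataset path and sizes are illustrative. One behavior worth noticing: get_episode() extends query_examples with all k_shot + n_query selected indices, so query episodes also contain the support samples.

import torchvision as tv
from src.FewShotEpisoder import FewShotEpisoder  # module path as it existed before this commit

transform = tv.transforms.Compose([tv.transforms.Resize((224, 224)), tv.transforms.ToTensor()])
imageset = tv.datasets.ImageFolder(root="./data/omniglot-py/images_background/Futurama")
classes = list(imageset.class_to_idx.values())[:5]  # 5-way
episoder = FewShotEpisoder(imageset, classes, k_shot=5, n_query=2, transform=transform)
support_set, query_set = episoder.get_episode()
print(len(support_set), len(query_set))  # 25 and 35 with these settings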
src/evaluate.py
DELETED
@@ -1,48 +0,0 @@
-import random
-import torch
-from torch import nn
-from torch.utils.data import DataLoader
-import torchvision as tv
-from src.model.ProtoNet import ProtoNet
-from src.FewShotEpisoder import FewShotEpisoder
-
-def evaluate(MODEL: str, DATASET: str):
-  device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # select device
-
-  # load checkpoint
-  data = torch.load(MODEL)
-  n_way, k_shot, n_query = data["framework"]
-
-  # load model
-  model = ProtoNet(*data["model_config"].values()).to(device)
-  model.load_state_dict(data["state"])
-  model.eval()
-
-  # create FSL episode generator
-  imageset = tv.datasets.ImageFolder(root=DATASET)
-  unseen_classes = [_ for _ in random.sample(list(imageset.class_to_idx.values()), n_way)]
-  episoder = FewShotEpisoder(imageset, unseen_classes, k_shot, n_query, data["transform"])
-
-  # compute prototype from support examples
-  support_set, query_set = episoder.get_episode()
-  prototypes = list()
-  embedded_features_list = [[] for _ in range(len(support_set.classes))]
-  for embedded_feature, label in support_set: embedded_features_list[unseen_classes.index(label)].append(embedded_feature)
-  for embedded_features in embedded_features_list:
-    class_prototype = torch.stack(embedded_features).mean(dim=0)
-    prototypes.append(class_prototype.flatten())
-  prototypes = torch.stack(prototypes)
-  model.prototyping(prototypes)
-
-  # eval model
-  total_loss, count, n_problem = 0., 0, len(query_set)
-  criterion = nn.CrossEntropyLoss()
-  for feature, label in DataLoader(query_set, shuffle=True):
-    pred = model.forward(feature)
-    loss = criterion(pred, label)
-    total_loss += loss.item()
-    if torch.argmax(pred) == torch.argmax(label): count += 1
-  print(f"seen classes: {data['seen_classes']}\nunseen classes: {unseen_classes}\naccuracy: {count / n_problem:.4f}({count}/{n_problem})")
-# evaluate()
-
-if __name__ == "__main__": evaluate("./model/model.pth", "../data/omniglot-py/images_background/Futurama")
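The prototype stage above is a per-class mean of the stacked support features; condensed into tensor form, with illustrative shapes:

import torch

n_way, k_shot, dim = 5, 5, 2352            # dim assumes flattened 28x28 RGB features
support = torch.randn(n_way, k_shot, dim)  # k_shot flattened support features per class
prototypes = support.mean(dim=1)           # (n_way, dim): same result as the stack-and-mean loop above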
src/train.py
DELETED
@@ -1,73 +0,0 @@
-import random
-import torch.cuda
-import torchvision as tv
-from torch import nn
-from tqdm import tqdm
-from torch.utils.data import DataLoader
-from src.FewShotEpisoder import FewShotEpisoder
-from src.model.ProtoNet import ProtoNet
-from config import TRAINING_CONFIG, HYPERPARAMETER_CONFIG
-
-def train(DATASET: str, SAVE_TO: str, N_WAY: int, K_SHOT: int, N_QUERY: int, ITERS=TRAINING_CONFIG["iters"], EPOCHS=TRAINING_CONFIG["epochs"]):
-  device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # init device
-
-  # define transform
-  transform = tv.transforms.Compose([
-    tv.transforms.Resize((224, 224)),
-    tv.transforms.ToTensor(),
-    tv.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-  ]) # transform
-
-  # init episode generator
-  imageset = tv.datasets.ImageFolder(root=DATASET)
-  seen_classes = [_ for _ in random.sample(list(imageset.class_to_idx.values()), N_WAY)]
-  episoder = FewShotEpisoder(imageset, seen_classes, K_SHOT, N_QUERY, transform)
-
-  # init model
-  model_config = {"in_channels": 3, "hidden_channels": 26, "output_channels": 3}
-  model = ProtoNet(*model_config.values()).to(device)
-  optim = torch.optim.Adam(model.parameters(), lr=HYPERPARAMETER_CONFIG["lr"], weight_decay=HYPERPARAMETER_CONFIG["weight_decay"])
-  criterion = nn.CrossEntropyLoss()
-
-  progress_bar, whole_loss = tqdm(range(EPOCHS)), float()
-  support_set, query_set = episoder.get_episode()
-  for _ in progress_bar:
-    # STAGE1: compute prototype from support examples
-    prototypes = list()
-    embedded_features_list = [[] for _ in range(len(support_set.classes))]
-    for embedded_feature, label in support_set: embedded_features_list[seen_classes.index(label)].append(embedded_feature)
-    for embedded_features in embedded_features_list:
-      class_prototype = torch.stack(embedded_features).mean(dim=0)
-      prototypes.append(class_prototype.flatten())
-    # for
-    prototypes = torch.stack(prototypes)
-    model.prototyping(prototypes)
-    # STAGE2: update parameters from loss associated with prototypes
-    epochs_loss = 0.0
-    for _ in range(ITERS):
-      iter_loss = 0.0
-      for feature, label in DataLoader(query_set, shuffle=True):
-        loss = criterion(model.forward(feature), label)
-        iter_loss += loss.item()
-        optim.zero_grad()
-        loss.backward()
-        optim.step()
-      epochs_loss += iter_loss / len(query_set)
-    # for # for
-    epochs_loss = epochs_loss / ITERS
-    progress_bar.set_postfix(loss=epochs_loss)
-  # for
-
-  # saving the model's parameters and the other data
-  features = {
-    "state": model.state_dict(),
-    "model_config": model_config,
-    "transform": transform,
-    "seen_classes": seen_classes,
-    "framework": (N_WAY, K_SHOT, N_QUERY)
-  } # features
-  torch.save(features, SAVE_TO)
-  print(f"model save to {SAVE_TO}")
-# train()
-
-if __name__ == "__main__": train("../data/omniglot-py/images_background/Futurama", "./model/model.pth", 5, 5, 2)
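One detail worth flagging in the deleted loop: ProtoNet.forward already returns LogSoftmax outputs, and nn.CrossEntropyLoss applies log_softmax internally, so the log-probabilities get renormalized a second time. A sketch of the conventional pairing, nn.NLLLoss on log-probabilities with class-index targets (hypothetical tensors, not this repo's data):

import torch
from torch import nn

log_probs = torch.log_softmax(-torch.rand(4, 5), dim=1)  # stand-in for model(feature): log-probs over 5 classes
targets = torch.tensor([0, 2, 1, 4])                     # class indices rather than one-hot vectors
criterion = nn.NLLLoss()                                 # expects log-probabilities directly
loss = criterion(log_probs, targets)
print(loss.item())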