big update
- .gitignore +2 -1
- LICENSE +21 -0
- README.md +8 -15
- config.py +8 -0
- run.py +12 -15
- src/FewShotEpisoder.py +3 -1
- src/{eval.py → evaluate.py} +14 -14
- src/model/ProtoNet.py +4 -4
- src/train.py +32 -18
.gitignore CHANGED
@@ -1,6 +1,7 @@
 # user-defined
-data
+**/data
 *.pth
+**/__pycache__/
 
 # pycharm
 .idea
LICENSE ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 한명환
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md CHANGED
@@ -1,11 +1,3 @@
----
-license: mit
-datasets:
-- dpdl-benchmark/omniglot
-language:
-- ko
-pipeline_tag: image-classification
----
 `torch` `torchvision` `tqdm`
 
 This implementation is inspired by **"Prototypical Networks for Few-Shot Learning" (Snell et al., 2017)**.
@@ -17,7 +9,9 @@ This repository implements a Prototypical Network for few-shot image classification
 
 Few-shot learning aims to enable models to generalize to new classes with only a few labeled examples. Prototypical Networks achieve this by computing a prototype (mean embedding) for each class and classifying query samples based on their distances to these prototypes in the embedding space.
 
-
+> You can access the full documentation here: [gitbook](https://lif31up.gitbook.io/lif31up/meta-learning/prototypical-networks-for-few-shot-learning)
+
+> You can access the test result on colab here: [colab](https://colab.research.google.com/drive/1gsVtGvISCpXQZsKvFjLVocn89ovazusE?usp=sharing)
 
 ## Instruction
 Organize your dataset into a structure compatible with PyTorch's ImageFolder:
@@ -39,21 +33,20 @@ Run the training script with desired parameters:
 ```
 python run.py train --dataset path/to/your/dataset --save_to /path/to/save/model --n_way 5 --k_shot 2 --n_query 4 --epochs 1 --iters 4
 ```
-* `
+* `dataset`: Path to your dataset.
 * `save_to`: Path to save the trained model.
 * `n_way`: Number of classes in each episode.
 * `k_shot`: Number of support samples per class.
 * `n_query`: Number of query samples per class.
-
-
+
+> Change the training configuration in `config.py`.
 
 ### Evaluation
 ```
 python run.py --dataset path/to/your/dataset --model path/to/saved/model.pth
 ```
-* `
+* `dataset`: Path to your dataset.
 * `model`: Path to your model.
-* `n_way`: Number of classes in each episode.
 
 ### Download Omniglot Dataset
 ```
@@ -68,4 +61,4 @@ Prototypical Networks are a powerful approach for **few-shot learning**, where t
 * **Embedding Representation with CNN**: Each input image is passed through a convolutional encoder to obtain a feature embedding.
 * **Prototype Computation**: The prototype for each class is computed as the mean of the embeddings of support samples belonging to that class.
 * **Distance-Based Classification**: Query samples are classified based on the distance (using `torch.cdist`) to the nearest prototype.
-* **Optimization**: The network is trained to minimize the distance between query samples and their correct prototypes while maximizing the distance to incorrect ones.
+* **Optimization**: The network is trained to minimize the distance between query samples and their correct prototypes while maximizing the distance to incorrect ones.
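The four concept bullets at the tail of this README diff map directly onto a few lines of PyTorch. Below is a minimal, self-contained sketch of prototype computation and nearest-prototype classification; the function and tensor names are hypothetical, not code from this repository:

```python
import torch

def classify_by_prototype(support: torch.Tensor, support_labels: torch.Tensor,
                          query: torch.Tensor, n_way: int) -> torch.Tensor:
  """support: (n_way * k_shot, D) embeddings; query: (Q, D) embeddings."""
  # prototype per class = mean of that class's support embeddings
  prototypes = torch.stack([support[support_labels == c].mean(dim=0) for c in range(n_way)])
  dists = torch.cdist(query, prototypes, p=2)  # (Q, n_way) L2 distances
  return dists.argmin(dim=1)                   # nearest prototype wins

# toy usage: 5-way, 2-shot, 16-dim embeddings
support = torch.randn(10, 16)
support_labels = torch.arange(5).repeat_interleave(2)  # [0,0,1,1,...,4,4]
query = torch.randn(4, 16)
print(classify_by_prototype(support, support_labels, query, n_way=5))
```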
config.py ADDED
@@ -0,0 +1,8 @@
+HYPERPARAMETER_CONFIG = {
+  "lr": 0.001,
+  "weight_decay": 0.0001
+} # HYPERPARAMETER_CONFIG
+TRAINING_CONFIG = {
+  "iters": 10,
+  "epochs": 10,
+} # TRAINING_CONFIG
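These dicts are plain module-level constants; `src/train.py` below consumes them as optimizer kwargs and default loop bounds. A minimal sketch of that usage, where the `nn.Linear` placeholder is hypothetical and merely stands in for `ProtoNet`:

```python
import torch
from torch import nn
from config import HYPERPARAMETER_CONFIG, TRAINING_CONFIG

model = nn.Linear(8, 2)  # hypothetical placeholder module
optim = torch.optim.Adam(model.parameters(),
                         lr=HYPERPARAMETER_CONFIG["lr"],
                         weight_decay=HYPERPARAMETER_CONFIG["weight_decay"])
print(TRAINING_CONFIG["epochs"], TRAINING_CONFIG["iters"])  # default loop bounds for train()
```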
run.py CHANGED
@@ -1,33 +1,30 @@
 import argparse
-import src.train as train
-import src.eval as eval
 import torchvision as tv
+from src.train import train
+from src.evaluate import evaluate
 
 def main():
   # eval(default)
   parser = argparse.ArgumentParser(description="Few-shot learning using Prototypical Network")
-  parser.add_argument("--path", type=str, help="path of your model")
   parser.add_argument("--model", type=str, help="path of your model")
-  parser.add_argument("--
+  parser.add_argument("--dataset", type=str, help="path of your dataset")
 
   # train
   subparser = parser.add_subparsers(title="subcommands", dest="subcommand")
   parser_train = subparser.add_parser("train", help="train your model")
-  parser_train.add_argument("--
+  parser_train.add_argument("--dataset", type=str, help="path to your dataset")
   parser_train.add_argument("--save_to", type=str, help="path to save your model")
   parser_train.add_argument("--n_way", type=int, help="number of classes per episode")
   parser_train.add_argument("--k_shot", type=int, help="number of support samples per class")
   parser_train.add_argument("--n_query", type=int, help="number of query samples per class")
   parser_train.add_argument("--iters", type=int, help="number of iterations per episode")
   parser_train.add_argument("--epochs", type=int, help="number of training epochs")
-  parser_train.set_defaults(func=lambda kwargs: train
-
-
-
-
-
-    iters=kwargs.iters,
-    epochs=kwargs.epochs)
+  parser_train.set_defaults(func=lambda kwargs: train(
+    DATASET=kwargs.dataset,
+    SAVE_TO=kwargs.save_to,
+    N_WAY=kwargs.n_way,
+    K_SHOT=kwargs.k_shot,
+    N_QUERY=kwargs.n_query)
   ) # parser_train.set_defaults()
 
   # download dataset
@@ -35,10 +32,10 @@ def main():
   parser_download.add_argument("--path", type=str, help="path to download dataset")
   parser_download.set_defaults(func=lambda kwargs: tv.datasets.Omniglot(root=kwargs.path, background=True, download=True))
 
+  # parse logic
   args = parser.parse_args()
   if hasattr(args, 'func'): args.func(args)
-
-  else: print("invalid argument. exiting program.")
+  else: evaluate(MODEL=args.model, DATASET=args.dataset)
 # main():
 
 if __name__ == "__main__": main()
src/FewShotEpisoder.py CHANGED
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 class FewShotDataset(Dataset):
   """ A custom Dataset class for Few-Shot Learning tasks.
   This dataset can operate in two modes: "support" (for prototype calculation) and "query" (for evaluation). """
-  def __init__(self, dataset
+  def __init__(self, dataset, indices: list, classes: list, transform: typing.Callable, mode="support"):
     """ Args:
       dataset (list): List of (feature, label) pairs.
       indices (list): List of indices to be used for the dataset.
@@ -59,6 +59,8 @@ class FewShotEpisoder:
     indices_c = {label: [] for label in range(len(self.classes))}
     for index, (_, label) in enumerate(self.dataset):
       if label in self.classes: indices_c[label].append(index)
+    for label, _indices_c in indices_c.items():
+      indices_c[label] = random.sample(_indices_c, self.k_shot + self.n_query)
     return indices_c
   # get_indices():
 
src/{eval.py → evaluate.py} RENAMED
@@ -1,25 +1,25 @@
 import torch
 from torch import nn
 from torch.utils.data import DataLoader
-from src.FewShotEpisoder import FewShotEpisoder
-from src.model.ProtoNet import ProtoNet
 import torchvision as tv
+from src.model.ProtoNet import ProtoNet
+from src.FewShotEpisoder import FewShotEpisoder
+
+def evaluate(MODEL: str, DATASET: str):
+  device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # select device
 
-
-
+  # load model
+  data = torch.load(MODEL)
+  n_way, k_shot, n_query = data["framework"]
 
   # load model
-
-
-  transform = data["transform"]
-  model = ProtoNet().to(device)
-  model.load_state_dict(state)
+  model = ProtoNet(*data["model_config"].values()).to(device)
+  model.load_state_dict(data["state"])
   model.eval()
 
   # create FSL episode generator
-  imageset = tv.datasets.ImageFolder(root=
-
-  episoder = FewShotEpisoder(imageset, chosen_classes, 2, 1, transform)
+  imageset = tv.datasets.ImageFolder(root=DATASET)
+  episoder = FewShotEpisoder(imageset, data["chosen_classes"], k_shot, n_query, data["transform"])
 
   # compute prototype from support examples
   support_set, query_set = episoder.get_episode()
@@ -40,7 +40,7 @@ def main(model: str, path: str, n_way=5):
     loss = criterion(pred, label)
     total_loss += loss.item()
     if torch.argmax(pred) == torch.argmax(label): count += 1
-  print(f"
+  print(f"accuracy: {count / n_problem:.4f} ({count}/{n_problem})")
 # main()
 
-if __name__ == "__main__":
+if __name__ == "__main__": evaluate("./model/model.pth", "../data/omniglot-py/images_background/Futurama")
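Read together with `src/train.py` later in this commit, the rewritten `evaluate()` pins down a checkpoint contract: the `.pth` file must carry everything needed to rebuild the model and its episodes. A sketch of the keys this diff implies, with an illustrative path:

```python
import torch

data = torch.load("path/to/model.pth")      # illustrative path
state = data["state"]                        # model weights for load_state_dict()
model_config = data["model_config"]          # ProtoNet constructor args
transform = data["transform"]                # torchvision transform used at train time
chosen_classes = data["chosen_classes"]      # class indices the model was trained on
n_way, k_shot, n_query = data["framework"]   # episode shape
```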
src/model/ProtoNet.py CHANGED
@@ -1,13 +1,13 @@
 import torch
 from torch import nn
-import torch.nn.functional as F
+import torch.nn.functional as torch_f
 
 class ProtoNet(nn.Module):
-  def __init__(self, in_channels=3, hidden_channel=26):
+  def __init__(self, in_channels=3, hidden_channel=26, output_channel=3):
     super(ProtoNet, self).__init__()
     self.conv1 = nn.Conv2d(in_channels, hidden_channel, kernel_size=3, stride=1, padding=1)
     self.conv2 = nn.Conv2d(hidden_channel, hidden_channel, kernel_size=3, stride=1, padding=1)
-    self.conv3 = nn.Conv2d(hidden_channel,
+    self.conv3 = nn.Conv2d(hidden_channel, output_channel, kernel_size=3, stride=1, padding=1)
     self.relu = nn.ReLU()
     self.flatten = nn.Flatten()
     self.softmax = nn.LogSoftmax(dim=1)
@@ -21,7 +21,7 @@ class ProtoNet(nn.Module):
     if metric == "euclidean":
       dists = torch.cdist(x, self.prototypes, p=2) # L2 distance
     elif metric == "cosine":
-      dists = 1 -
+      dists = 1 - torch_f.cosine_similarity(x.unsqueeze(1), self.prototypes.unsqueeze(0), dim=2) # 1 - cosine similarity
     else:
       raise ValueError("Unsupported distance metric. Choose 'euclidean' or 'cosine'.")
     return dists
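The cosine branch above broadcasts the queries against the prototypes; a quick standalone shape check (random tensors, not repository code) confirms both metrics yield a `(num_queries, num_prototypes)` distance matrix:

```python
import torch
import torch.nn.functional as torch_f

x = torch.randn(8, 16)           # 8 query embeddings of dim 16
prototypes = torch.randn(5, 16)  # 5 class prototypes

euclidean = torch.cdist(x, prototypes, p=2)                      # (8, 5)
cosine = 1 - torch_f.cosine_similarity(x.unsqueeze(1),           # (8, 1, 16)
                                       prototypes.unsqueeze(0),  # (1, 5, 16)
                                       dim=2)                    # broadcast -> (8, 5)
assert euclidean.shape == cosine.shape == (8, 5)
```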
src/train.py CHANGED
@@ -1,32 +1,37 @@
 import torch.cuda
 import torchvision as tv
 from torch import nn
+from tqdm import tqdm
 from torch.utils.data import DataLoader
 from src.FewShotEpisoder import FewShotEpisoder
 from src.model.ProtoNet import ProtoNet
-from
+from config import TRAINING_CONFIG, HYPERPARAMETER_CONFIG
 
-def
+def train(DATASET: str, SAVE_TO: str, N_WAY: int, K_SHOT: int, N_QUERY: int, ITERS=TRAINING_CONFIG["iters"], EPOCHS=TRAINING_CONFIG["epochs"]):
   device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # init device
 
-  #
+  # define transform
   transform = tv.transforms.Compose([
     tv.transforms.Resize((224, 224)),
     tv.transforms.ToTensor(),
    tv.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
   ]) # transform
-  imageset = tv.datasets.ImageFolder(root=path)
-  chosen_classes = list(imageset.class_to_idx.values())[:n_way]
-  episoder = FewShotEpisoder(imageset, chosen_classes, k_shot, n_query, transform)
 
-  # init
-
-
+  # init episode generator
+  imageset = tv.datasets.ImageFolder(root=DATASET)
+  chosen_classes = list(imageset.class_to_idx.values())[:N_WAY]
+  episoder = FewShotEpisoder(imageset, chosen_classes, K_SHOT, N_QUERY, transform)
+
+  # init model
+  model_config = {"in_channels": 3, "hidden_channels": 26, "output_channels": 3}
+  model = ProtoNet(*model_config.values()).to(device)
+  optim = torch.optim.Adam(model.parameters(), lr=HYPERPARAMETER_CONFIG["lr"], weight_decay=HYPERPARAMETER_CONFIG["weight_decay"])
   criterion = nn.CrossEntropyLoss()
 
-
+  progress_bar, whole_loss = tqdm(range(EPOCHS)), float()
+  for _ in progress_bar:
     support_set, query_set = episoder.get_episode()
-    # compute prototype from support examples
+    # STAGE1: compute prototype from support examples
     prototypes = list()
     embedded_features_list = [[] for _ in range(len(support_set.classes))]
     for embedded_feature, label in support_set: embedded_features_list[label].append(embedded_feature)
@@ -36,23 +41,32 @@ def main(path, save_to, n_way=5, k_shot=5, n_query=2, iters=10, epochs=1):
     # for
     prototypes = torch.stack(prototypes)
     model.prototyping(prototypes)
-
-
+    # STAGE2: update parameters from the loss associated with prototypes
+    epochs_loss = 0.0
+    for _ in range(ITERS):
+      iter_loss = 0.0
       for feature, label in DataLoader(query_set, shuffle=True):
         loss = criterion(model.forward(feature), label)
-
+        iter_loss += loss.item()
        optim.zero_grad()
         loss.backward()
         optim.step()
-
-
+      epochs_loss += iter_loss / len(query_set)
+    # for # for
+    epochs_loss = epochs_loss / ITERS
+    progress_bar.set_postfix(loss=epochs_loss)
+  # for
 
   # saving the model's parameters and the other data
   features = {
     "state": model.state_dict(),
+    "model_config": model_config,
     "transform": transform,
+    "chosen_classes": chosen_classes,
+    "framework": (N_WAY, K_SHOT, N_QUERY)
   } # features
-  torch.save(features,
+  torch.save(features, SAVE_TO)
+  print(f"model saved to {SAVE_TO}")
 # main()
 
-if __name__ == "__main__":
+if __name__ == "__main__": train("../data/omniglot-py/images_background/Futurama", "./model/model.pth", 5, 5, 2, 5, 5)