lif31up committed
Commit c4c047d · Parent(s): 71d5391

big update
Files changed (9)
  1. .gitignore +2 -1
  2. LICENSE +21 -0
  3. README.md +8 -15
  4. config.py +8 -0
  5. run.py +12 -15
  6. src/FewShotEpisoder.py +3 -1
  7. src/{eval.py → evaluate.py} +14 -14
  8. src/model/ProtoNet.py +4 -4
  9. src/train.py +32 -18
.gitignore CHANGED
@@ -1,6 +1,7 @@
 # user-defined
-data/raw
+**/data
 *.pth
+**/__pycache__/

 # pycharm
 .idea
LICENSE ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 한명환
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md CHANGED
@@ -1,11 +1,3 @@
----
-license: mit
-datasets:
-- dpdl-benchmark/omniglot
-language:
-- ko
-pipeline_tag: image-classification
----
 `torch` `torchvision` `tqdm`

 This implementation is inspired by **"Prototypical Networks for Few-Shot Learning" (Snell et al., 2017)**.
@@ -17,7 +9,9 @@ This repository implements a Prototypical Network for few-shot image classificat

 Few-shot learning aims to enable models to generalize to new classes with only a few labeled examples. Prototypical Networks achieve this by computing a prototype (mean embedding) for each class and classifying query samples based on their distances to these prototypes in the embedding space.

-<a href="https://colab.research.google.com/drive/1gsVtGvISCpXQZsKvFjLVocn89ovazusE?usp=sharing">Test Result on Colab</a>
+> You can access the full documentation here: [gitbook](https://lif31up.gitbook.io/lif31up/meta-learning/prototypical-networks-for-few-shot-learning)
+
+> You can access the test results on Colab here: [colab](https://colab.research.google.com/drive/1gsVtGvISCpXQZsKvFjLVocn89ovazusE?usp=sharing)

 ## Instruction
 Organize your dataset into a structure compatible with PyTorch's ImageFolder:
@@ -39,21 +33,20 @@ Run the training script with desired parameters:
 ```
-python run.py train --dataset_path path/to/your/dataset --save_to path/to/save/model --n_way 5 --k_shot 2 --n_query 4 --epochs 1 --iters 4
+python run.py train --dataset path/to/your/dataset --save_to path/to/save/model --n_way 5 --k_shot 2 --n_query 4
 ```
-* `dataset_path`: Path to your dataset.
+* `dataset`: Path to your dataset.
 * `save_to`: Path to save the trained model.
 * `n_way`: Number of classes in each episode.
 * `k_shot`: Number of support samples per class.
 * `n_query`: Number of query samples per class.
-* `epochs`: Number of episodes.
-* `iters`: Number of training epochs.
+
+> Change the training configuration in `config.py`.

 ### Evaluation
 ```
-python run.py --path path/to/your/dataset --model path/to/saved/model.pth --n_way 5
+python run.py --dataset path/to/your/dataset --model path/to/saved/model.pth
 ```
-* `path`: Path to your dataset.
+* `dataset`: Path to your dataset.
 * `model`: Path to your model.
-* `n_way`: Number of classes in each episode.

 ### Download Omniglot Dataset
 ```
@@ -68,4 +61,4 @@ Prototypical Networks are a powerful approach for **few-shot learning**, where t
 * **Embedding Representation with CNN**: Each input image is passed through a convolutional encoder to obtain a feature embedding.
 * **Prototype Computation**: The prototype for each class is computed as the mean of the embeddings of support samples belonging to that class.
 * **Distance-Based Classification**: Query samples are classified based on the distance (using `torch.cdist`) to the nearest prototype.
 * **Optimization**: The network is trained to minimize the distance between query samples and their correct prototypes while maximizing the distance to incorrect ones.
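The four bullets above reduce to a few tensor operations. The following is a minimal sketch of the episode math with made-up shapes, not code taken from this repository:

```python
import torch

# Hypothetical sizes for illustration: 5-way, 2-shot, 64-dim embeddings.
n_way, k_shot, embed_dim = 5, 2, 64
support = torch.randn(n_way, k_shot, embed_dim)  # embedded support samples, grouped by class
queries = torch.randn(8, embed_dim)              # embedded query samples

prototypes = support.mean(dim=1)                 # (n_way, embed_dim): per-class mean embedding
dists = torch.cdist(queries, prototypes, p=2)    # (8, n_way): L2 distance to each prototype
log_probs = torch.log_softmax(-dists, dim=1)     # smaller distance -> higher class probability
predictions = dists.argmin(dim=1)                # classify each query by its nearest prototype
```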
config.py ADDED
@@ -0,0 +1,8 @@
+HYPERPARAMETER_CONFIG = {
+  "lr": 0.001,
+  "weight_decay": 0.0001
+} # HYPERPARAMETER_CONFIG
+TRAINING_CONFIG = {
+  "iters": 10,
+  "epochs": 10,
+} # TRAINING_CONFIG
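These dicts replace the hard-coded defaults that run.py used to forward to training; train.py now reads them directly (see the train.py diff below). A trimmed illustration of the consumption pattern, with a placeholder module standing in for ProtoNet:

```python
import torch
from torch import nn
from config import HYPERPARAMETER_CONFIG, TRAINING_CONFIG

model = nn.Linear(4, 2)  # placeholder module, only here to build an optimizer
optim = torch.optim.Adam(model.parameters(),
                         lr=HYPERPARAMETER_CONFIG["lr"],
                         weight_decay=HYPERPARAMETER_CONFIG["weight_decay"])
print(TRAINING_CONFIG["epochs"], TRAINING_CONFIG["iters"])  # 10 10
```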
run.py CHANGED
@@ -1,33 +1,30 @@
 import argparse
-import src.train as train
-import src.eval as eval
 import torchvision as tv
+from src.train import train
+from src.evaluate import evaluate

 def main():
   # eval (default)
   parser = argparse.ArgumentParser(description="Few-shot learning using Prototypical Network")
-  parser.add_argument("--path", type=str, help="path of your model")
   parser.add_argument("--model", type=str, help="path to your model")
-  parser.add_argument("--n_way", type=int, help="number of classes per episode")
+  parser.add_argument("--dataset", type=str, help="path to your dataset")

   # train
   subparser = parser.add_subparsers(title="subcommands", dest="subcommand")
   parser_train = subparser.add_parser("train", help="train your model")
-  parser_train.add_argument("--path", type=str, help="path to your dataset")
+  parser_train.add_argument("--dataset", type=str, help="path to your dataset")
   parser_train.add_argument("--save_to", type=str, help="path to save your model")
   parser_train.add_argument("--n_way", type=int, help="number of classes per episode")
   parser_train.add_argument("--k_shot", type=int, help="number of support samples per class")
   parser_train.add_argument("--n_query", type=int, help="number of query samples per class")
   parser_train.add_argument("--iters", type=int, help="how many iterations your model runs per episode")
   parser_train.add_argument("--epochs", type=int, help="how many epochs your model trains for")
-  parser_train.set_defaults(func=lambda kwargs: train.main(
-    path=kwargs.dataset_path,
-    save_to=kwargs.save_to,
-    n_way=kwargs.n_way,
-    k_shot=kwargs.k_shot,
-    n_query=kwargs.n_query,
-    iters=kwargs.iters,
-    epochs=kwargs.epochs)
+  parser_train.set_defaults(func=lambda kwargs: train(
+    DATASET=kwargs.dataset,
+    SAVE_TO=kwargs.save_to,
+    N_WAY=kwargs.n_way,
+    K_SHOT=kwargs.k_shot,
+    N_QUERY=kwargs.n_query)
   ) # parser_train.set_defaults()

   # download dataset
@@ -35,10 +32,10 @@ def main():
   parser_download.add_argument("--path", type=str, help="path to download dataset")
   parser_download.set_defaults(func=lambda kwargs: tv.datasets.Omniglot(root=kwargs.path, background=True, download=True))

+  # parse logic
   args = parser.parse_args()
   if hasattr(args, 'func'): args.func(args)
-  elif args.path and args.model: eval.main(model=args.path, path=args.model, n_way=args.n_way)
-  else: print("invalid argument. exiting program.")
+  else: evaluate(MODEL=args.model, DATASET=args.dataset)
 # main():

 if __name__ == "__main__": main()
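With the renamed flags, the two entry points from the README look like this (paths and episode sizes are placeholders):

```
python run.py train --dataset path/to/your/dataset --save_to path/to/save/model.pth --n_way 5 --k_shot 2 --n_query 4
python run.py --dataset path/to/your/dataset --model path/to/saved/model.pth
```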
src/FewShotEpisoder.py CHANGED
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 class FewShotDataset(Dataset):
   """ A custom Dataset class for Few-Shot Learning tasks.
   This dataset can operate in two modes: "support" (for prototype calculation) and "query" (for evaluation). """
-  def __init__(self, dataset: typing.Iterable, indices: list, classes: list, transform: typing.Callable, mode="support"):
+  def __init__(self, dataset, indices: list, classes: list, transform: typing.Callable, mode="support"):
     """ Args:
       dataset (list): List of (feature, label) pairs.
       indices (list): List of indices to be used for the dataset.
@@ -59,6 +59,8 @@ class FewShotEpisoder:
     indices_c = {label: [] for label in range(len(self.classes))}
     for index, (_, label) in enumerate(self.dataset):
       if label in self.classes: indices_c[label].append(index)
+    for label, _indices_c in indices_c.items():
+      indices_c[label] = random.sample(_indices_c, self.k_shot + self.n_query)
     return indices_c
   # get_indices():
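The added loop trims every class to exactly `k_shot + n_query` sampled indices. It assumes `random` is imported at the top of FewShotEpisoder.py (not visible in this hunk), and `random.sample` raises `ValueError` for any class with fewer examples than that. A standalone sketch of the behavior, using toy index pools rather than this repository's data:

```python
import random

k_shot, n_query = 2, 4
indices_c = {0: list(range(10)), 1: list(range(8))}  # toy per-class index pools
for label, pool in indices_c.items():
  indices_c[label] = random.sample(pool, k_shot + n_query)  # ValueError if a pool is too small
print({label: len(v) for label, v in indices_c.items()})  # {0: 6, 1: 6}
```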
src/{eval.py → evaluate.py} RENAMED
@@ -1,25 +1,25 @@
 import torch
 from torch import nn
 from torch.utils.data import DataLoader
-from src.FewShotEpisoder import FewShotEpisoder
-from src.model.ProtoNet import ProtoNet
 import torchvision as tv
+from src.model.ProtoNet import ProtoNet
+from src.FewShotEpisoder import FewShotEpisoder

-def main(model: str, path: str, n_way=5):
-  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+def evaluate(MODEL: str, DATASET: str):
+  device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # select device

-  # load model
-  data = torch.load(model)
-  state = data["state"]
-  transform = data["transform"]
-  model = ProtoNet().to(device)
-  model.load_state_dict(state)
+  # load checkpoint
+  data = torch.load(MODEL)
+  n_way, k_shot, n_query = data["framework"]
+
+  # load model
+  model = ProtoNet(*data["model_config"].values()).to(device)
+  model.load_state_dict(data["state"])
   model.eval()

   # create FSL episode generator
-  imageset = tv.datasets.ImageFolder(root=path)
-  chosen_classes = list(imageset.class_to_idx.values())[:n_way]
-  episoder = FewShotEpisoder(imageset, chosen_classes, 2, 1, transform)
+  imageset = tv.datasets.ImageFolder(root=DATASET)
+  episoder = FewShotEpisoder(imageset, data["chosen_classes"], k_shot, n_query, data["transform"])

   # compute prototype from support examples
   support_set, query_set = episoder.get_episode()
@@ -40,7 +40,7 @@ def main(model: str, path: str, n_way=5):
   loss = criterion(pred, label)
   total_loss += loss.item()
   if torch.argmax(pred) == torch.argmax(label): count += 1
-  print(f"loss: {total_loss / len(query_set):.4f} accuracy: {count / n_problem:.4f}({count}/{n_problem})")
+  print(f"accuracy: {count / n_problem:.4f} ({count}/{n_problem})")
-# main()
+# evaluate()

-if __name__ == "__main__": main("./model/model.pth", "../data/raw/omniglot-py/images_background/Futurama")
+if __name__ == "__main__": evaluate("./model/model.pth", "../data/omniglot-py/images_background/Futurama")
src/model/ProtoNet.py CHANGED
@@ -1,13 +1,13 @@
 import torch
 from torch import nn
-import torch.nn.functional as F
+import torch.nn.functional as torch_f

 class ProtoNet(nn.Module):
-  def __init__(self, in_channels=3, hidden_channel=26):
+  def __init__(self, in_channels=3, hidden_channel=26, output_channel=3):
     super(ProtoNet, self).__init__()
     self.conv1 = nn.Conv2d(in_channels, hidden_channel, kernel_size=3, stride=1, padding=1)
     self.conv2 = nn.Conv2d(hidden_channel, hidden_channel, kernel_size=3, stride=1, padding=1)
-    self.conv3 = nn.Conv2d(hidden_channel, in_channels, kernel_size=3, stride=1, padding=1)
+    self.conv3 = nn.Conv2d(hidden_channel, output_channel, kernel_size=3, stride=1, padding=1)
     self.relu = nn.ReLU()
     self.flatten = nn.Flatten()
     self.softmax = nn.LogSoftmax(dim=1)
@@ -21,7 +21,7 @@ class ProtoNet(nn.Module):
     if metric == "euclidean":
       dists = torch.cdist(x, self.prototypes, p=2) # L2 distance
     elif metric == "cosine":
-      dists = 1 - F.cosine_similarity(x.unsqueeze(1), self.prototypes.unsqueeze(0), dim=2) # 1 - cosine similarity
+      dists = 1 - torch_f.cosine_similarity(x.unsqueeze(1), self.prototypes.unsqueeze(0), dim=2) # 1 - cosine similarity
     else:
       raise ValueError("Unsupported distance metric. Choose 'euclidean' or 'cosine'.")
     return dists
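How the two branches of `distance()` compare on stand-in tensors (shapes chosen for illustration; not this repository's code):

```python
import torch
import torch.nn.functional as torch_f

x = torch.randn(4, 16)           # 4 embedded queries
prototypes = torch.randn(5, 16)  # 5 class prototypes

euclidean = torch.cdist(x, prototypes, p=2)  # (4, 5) pairwise L2 distances
cosine = 1 - torch_f.cosine_similarity(
  x.unsqueeze(1), prototypes.unsqueeze(0), dim=2)  # (4, 5) via broadcasting
```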
src/train.py CHANGED
@@ -1,32 +1,37 @@
 import torch.cuda
 import torchvision as tv
 from torch import nn
+from tqdm import tqdm
 from torch.utils.data import DataLoader
 from src.FewShotEpisoder import FewShotEpisoder
 from src.model.ProtoNet import ProtoNet
-from tqdm import tqdm
+from config import TRAINING_CONFIG, HYPERPARAMETER_CONFIG

-def main(path, save_to, n_way=5, k_shot=5, n_query=2, iters=10, epochs=1):
+def train(DATASET: str, SAVE_TO: str, N_WAY: int, K_SHOT: int, N_QUERY: int, ITERS=TRAINING_CONFIG["iters"], EPOCHS=TRAINING_CONFIG["epochs"]):
   device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # init device

-  # create FSL episode generator
+  # define transform
   transform = tv.transforms.Compose([
     tv.transforms.Resize((224, 224)),
     tv.transforms.ToTensor(),
     tv.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
   ]) # transform
-  imageset = tv.datasets.ImageFolder(root=path)
-  chosen_classes = list(imageset.class_to_idx.values())[:n_way]
-  episoder = FewShotEpisoder(imageset, chosen_classes, k_shot, n_query, transform)

-  # init learning
-  model = ProtoNet(3).to(device)
-  optim = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)
+  # init episode generator
+  imageset = tv.datasets.ImageFolder(root=DATASET)
+  chosen_classes = list(imageset.class_to_idx.values())[:N_WAY]
+  episoder = FewShotEpisoder(imageset, chosen_classes, K_SHOT, N_QUERY, transform)
+
+  # init model
+  model_config = {"in_channels": 3, "hidden_channel": 26, "output_channel": 3}
+  model = ProtoNet(*model_config.values()).to(device)
+  optim = torch.optim.Adam(model.parameters(), lr=HYPERPARAMETER_CONFIG["lr"], weight_decay=HYPERPARAMETER_CONFIG["weight_decay"])
   criterion = nn.CrossEntropyLoss()

-  for _ in tqdm(range(epochs), desc="epochs/episodes"):
+  progress_bar = tqdm(range(EPOCHS))
+  for _ in progress_bar:
     support_set, query_set = episoder.get_episode()
-    # compute prototype from support examples
+    # STAGE 1: compute prototypes from support examples
     prototypes = list()
     embedded_features_list = [[] for _ in range(len(support_set.classes))]
     for embedded_feature, label in support_set: embedded_features_list[label].append(embedded_feature)
@@ -36,23 +41,32 @@ def main(path, save_to, n_way=5, k_shot=5, n_query=2, iters=10, epochs=1):
     # for
     prototypes = torch.stack(prototypes)
     model.prototyping(prototypes)
-    for _ in tqdm(range(iters), desc="\titerations/queries"):
-      total_loss = 0.0
+    # STAGE 2: update parameters from the loss associated with the prototypes
+    epochs_loss = 0.0
+    for _ in range(ITERS):
+      iter_loss = 0.0
       for feature, label in DataLoader(query_set, shuffle=True):
         loss = criterion(model.forward(feature), label)
-        total_loss += loss.item()
+        iter_loss += loss.item()
         optim.zero_grad()
         loss.backward()
         optim.step()
-      print(f"loss: {total_loss / len(query_set):.4f}")
-  # for for for
+      epochs_loss += iter_loss / len(query_set)
+    # for # for
+    epochs_loss = epochs_loss / ITERS
+    progress_bar.set_postfix(loss=epochs_loss)
+  # for

   # saving the model's parameters and the other data
   features = {
     "state": model.state_dict(),
+    "model_config": model_config,
     "transform": transform,
+    "chosen_classes": chosen_classes,
+    "framework": (N_WAY, K_SHOT, N_QUERY)
   } # features
-  torch.save(features, save_to)
-# main()
+  torch.save(features, SAVE_TO)
+  print(f"model saved to {SAVE_TO}")
+# train()

-if __name__ == "__main__": main(path="../data/raw/omniglot-py/images_background/Futurama", save_to="./model/model.pth")
+if __name__ == "__main__": train("../data/omniglot-py/images_background/Futurama", "./model/model.pth", 5, 5, 2, 5, 5)
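The expanded `features` dict is the contract between `train()` and `evaluate()`: everything evaluation needs now travels inside the checkpoint. A sketch of the load side, assuming a checkpoint saved by this `train()` at a placeholder path:

```python
import torch

data = torch.load("model.pth")               # placeholder path
n_way, k_shot, n_query = data["framework"]   # episode shape fixed at training time
model_config = data["model_config"]          # channel sizes to rebuild ProtoNet
chosen_classes = data["chosen_classes"]      # class indices the prototypes were built from
transform = data["transform"]                # the exact preprocessing pipeline
# evaluate() rebuilds ProtoNet(*model_config.values()) and then loads data["state"].
```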