lif31up committed
Commit 9d58bb5 · 1 Parent(s): 2b6add1

init: safetensor, bin

config.json ADDED
@@ -0,0 +1,8 @@
+ {
+   "sate": "model.state_dict()",
+   "FRAMEWORK": [5, 5, 2],
+   "MODEL_CONFIG": [3, 26, 3],
+   "HYPER_PARAMETERS": {"lr": 0.0001, "weight_decay": 0.0001},
+   "TRANSFORM": "transform",
+   "METRIC": "euclidean"
+ }
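The quoted values "model.state_dict()" and "transform" read as placeholders for objects that are actually pickled into the .bin checkpoint (see train.py below), and the "sate" key looks like a typo for "state", which is what train.py saves. A minimal sketch of consuming this file, assuming the lists encode (n_way, k_shot, n_query) and (in_channels, hidden_channel, output_channel), which matches the numbers used elsewhere in this commit:

import json

with open("config.json") as f:
  cfg = json.load(f)

n_way, k_shot, n_query = cfg["FRAMEWORK"]                        # 5, 5, 2
in_channels, hidden_channel, out_channels = cfg["MODEL_CONFIG"]  # 3, 26, 3
lr = cfg["HYPER_PARAMETERS"]["lr"]                               # 0.0001
metric = cfg["METRIC"]                                           # "euclidean"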
config.py DELETED
@@ -1,8 +0,0 @@
- HYPERPARAMETER_CONFIG = {
-   "lr": 0.0001,
-   "weight_decay": 0.0001
- } # HYPERPARAMETER_CONFIG
- TRAINING_CONFIG = {
-   "iters": 10,
-   "epochs": 30,
- } # TRAINING_CONFIG
src/model/ProtoNet.py → model.py RENAMED
@@ -1,13 +1,10 @@
- import torch
- from torch import nn
- import torch.nn.functional as torch_f
-
  class ProtoNet(nn.Module):
-   def __init__(self, in_channels=3, hidden_channel=26, output_channel=3):
      super(ProtoNet, self).__init__()
      self.conv1 = nn.Conv2d(in_channels, hidden_channel, kernel_size=3, stride=1, padding=1)
      self.conv2 = nn.Conv2d(hidden_channel, hidden_channel, kernel_size=3, stride=1, padding=1)
-     self.conv3 = nn.Conv2d(hidden_channel, output_channel, kernel_size=3, stride=1, padding=1)
      self.relu = nn.ReLU()
      self.flatten = nn.Flatten()
      self.softmax = nn.LogSoftmax(dim=1)
@@ -16,18 +13,19 @@ class ProtoNet(nn.Module):
    def prototyping(self, prototypes): self.prototypes = prototypes

    def cdist(self, x: torch.Tensor, metric="euclidean") -> torch.Tensor:
      assert self.prototypes is not None, "Prototypes must be set before calling cdist."
      assert x.size(1) == self.prototypes.size(1), "Feature dimensions must match."
      if metric == "euclidean":
        dists = torch.cdist(x, self.prototypes, p=2) # L2 distance
      elif metric == "cosine":
-       dists = 1 - torch_f.cosine_similarity(x.unsqueeze(1), self.prototypes.unsqueeze(0), dim=2) # 1 - cosine similarity
      else:
        raise ValueError("Unsupported distance metric. Choose 'euclidean' or 'cosine'.")
      return dists
    # cdist()

-   def forward(self, x):
      x = self.conv1(x)
      x = self.relu(x)
      x = self.conv2(x)
@@ -35,7 +33,7 @@ class ProtoNet(nn.Module):
      x = self.conv3(x)
      x = self.relu(x)
      x = self.flatten(x)
-     x = self.cdist(x, metric="euclidean")
      return self.softmax(-x)
    # forward
  # ProtoNet
 
  class ProtoNet(nn.Module):
+   def __init__(self, in_channels, hidden_channel):
      super(ProtoNet, self).__init__()
+     # embedding network
      self.conv1 = nn.Conv2d(in_channels, hidden_channel, kernel_size=3, stride=1, padding=1)
      self.conv2 = nn.Conv2d(hidden_channel, hidden_channel, kernel_size=3, stride=1, padding=1)
+     self.conv3 = nn.Conv2d(hidden_channel, in_channels, kernel_size=3, stride=1, padding=1)
      self.relu = nn.ReLU()
      self.flatten = nn.Flatten()
      self.softmax = nn.LogSoftmax(dim=1)

    def prototyping(self, prototypes): self.prototypes = prototypes

    def cdist(self, x: torch.Tensor, metric="euclidean") -> torch.Tensor:
+     # distance function
      assert self.prototypes is not None, "Prototypes must be set before calling cdist."
      assert x.size(1) == self.prototypes.size(1), "Feature dimensions must match."
      if metric == "euclidean":
        dists = torch.cdist(x, self.prototypes, p=2) # L2 distance
      elif metric == "cosine":
+       dists = 1 - F.cosine_similarity(x.unsqueeze(1), self.prototypes.unsqueeze(0), dim=2) # 1 - cosine similarity
      else:
        raise ValueError("Unsupported distance metric. Choose 'euclidean' or 'cosine'.")
      return dists
    # cdist()

+   def forward(self, x, metric="euclidean"):
      x = self.conv1(x)
      x = self.relu(x)
      x = self.conv2(x)

      x = self.conv3(x)
      x = self.relu(x)
      x = self.flatten(x)
+     x = self.cdist(x, metric=metric)
      return self.softmax(-x)
    # forward
  # ProtoNet
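Note that the rename drops the import lines, so model.py as committed depends on torch, nn, and F being defined. A minimal usage sketch, assuming those imports are restored at the top of model.py; the 5-way/5-shot sizes and the 8x8 inputs are illustrative assumptions, not values from this repo:

import torch
from model import ProtoNet  # assumes model.py restores its torch / nn / F imports

def embed(m: ProtoNet, x: torch.Tensor) -> torch.Tensor:
  # mirrors forward() up to flatten(), stopping before the distance head
  x = m.relu(m.conv1(x))
  x = m.relu(m.conv2(x))
  x = m.relu(m.conv3(x))
  return m.flatten(x)

model = ProtoNet(in_channels=3, hidden_channel=26)
support = torch.randn(5 * 5, 3, 8, 8)  # 5 classes x 5 shots, ordered class-major
with torch.no_grad():
  prototypes = embed(model, support).view(5, 5, -1).mean(dim=1)  # (5, 3*8*8) class means
model.prototyping(prototypes)

query = torch.randn(2, 3, 8, 8)
print(model(query).shape)  # torch.Size([2, 5]): log-softmax over negative distances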
prototypical_network.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1207dad08600ea1a3f1a622d2469a3208ebc8dd007c3b6051b256d7df5103f03
+ size 34442
prototypical_network.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fe7d84d01a2cbd6369d78422355940f7f902eb78e9359f764d8ebc7c77eacad8
+ size 30620
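Both weight files are Git LFS pointers rather than raw bytes. A minimal loading sketch; that the .bin holds the full pickled checkpoint (per train.py below) while the .safetensors holds only the tensor state dict is an assumption:

import torch
from safetensors.torch import load_file  # pip install safetensors

# .bin: a pickled dict saved via torch.save (schema shown in train.py below)
data = torch.load("prototypical_network.bin", map_location="cpu", weights_only=False)
state = data["state"]

# .safetensors: tensor-only format, so presumably just the state dict itself
state = load_file("prototypical_network.safetensors")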
run.py DELETED
@@ -1,41 +0,0 @@
- import argparse
- import torchvision as tv
- from src.train import train
- from src.evaluate import evaluate
-
- def main():
-   # eval (default)
-   parser = argparse.ArgumentParser(description="Few-shot learning using Prototypical Network")
-   parser.add_argument("--model", type=str, help="path of your model")
-   parser.add_argument("--dataset", type=str, help="path of your dataset")
-
-   # train
-   subparser = parser.add_subparsers(title="subcommands", dest="subcommand")
-   parser_train = subparser.add_parser("train", help="train your model")
-   parser_train.add_argument("--dataset", type=str, help="path to your dataset")
-   parser_train.add_argument("--save_to", type=str, help="path to save your model")
-   parser_train.add_argument("--n_way", type=int, help="number of classes per episode")
-   parser_train.add_argument("--k_shot", type=int, help="number of support samples per class")
-   parser_train.add_argument("--n_query", type=int, help="number of query samples per class")
-   parser_train.add_argument("--iters", type=int, help="how many iterations your model runs per episode")
-   parser_train.add_argument("--epochs", type=int, help="how many epochs your model trains for")
-   parser_train.set_defaults(func=lambda kwargs: train(
-     DATASET=kwargs.dataset,
-     SAVE_TO=kwargs.save_to,
-     N_WAY=kwargs.n_way,
-     K_SHOT=kwargs.k_shot,
-     N_QUERY=kwargs.n_query)
-   ) # parser_train.set_defaults()
-
-   # download dataset
-   parser_download = subparser.add_parser("download", help="download dataset")
-   parser_download.add_argument("--path", type=str, help="path to download dataset")
-   parser_download.set_defaults(func=lambda kwargs: tv.datasets.Omniglot(root=kwargs.path, background=True, download=True))
-
-   # parse logic
-   args = parser.parse_args()
-   if hasattr(args, 'func'): args.func(args)
-   else: evaluate(MODEL=args.model, DATASET=args.dataset)
- # main()
-
- if __name__ == "__main__": main()
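The CLI hinges on argparse's set_defaults(func=...) dispatch: each subparser attaches a callable, and main() falls through to evaluate() when no subcommand is given. A self-contained sketch of the same pattern (the names here are illustrative, not the repo's):

import argparse

parser = argparse.ArgumentParser()
sub = parser.add_subparsers(title="subcommands", dest="subcommand")
p_train = sub.add_parser("train")
p_train.add_argument("--epochs", type=int, default=1)
p_train.set_defaults(func=lambda args: print(f"training for {args.epochs} epochs"))

args = parser.parse_args(["train", "--epochs", "3"])
if hasattr(args, "func"): args.func(args)  # prints: training for 3 epochs
else: print("no subcommand: fall through to evaluation")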
src/FewShotEpisoder.py DELETED
@@ -1,84 +0,0 @@
- import random
- import typing
- import torch
- from torch.utils.data import Dataset
- import torch.nn.functional as F
-
- class FewShotDataset(Dataset):
-   """ A custom Dataset class for Few-Shot Learning tasks.
-   This dataset can operate in two modes: "support" (for prototype calculation) and "query" (for evaluation). """
-   def __init__(self, dataset, indices: list, classes: list, transform: typing.Callable, mode="support"):
-     """ Args:
-       dataset (list): List of (feature, label) pairs.
-       indices (list): List of indices to be used for the dataset.
-       transform (callable): Transform to be applied to the features.
-       mode (str): Mode of operation, either "support" or "query". Default is "support". """
-     assert mode in ["support", "query"], "Invalid mode. Must be either 'support' or 'query'." # check if mode is valid
-     assert dataset and indices and classes is not None, "Dataset or indices cannot be None." # check if dataset is not None
-
-     self.dataset, self.indices, self.classes = dataset, indices, classes
-     self.mode, self.transform = mode, transform
-   # __init__():
-
-   def __getitem__(self, index: int):
-     """ Returns a sample from the dataset at the given index.
-     Args: index of the sample to be retrieved.
-     Returns: tuple of the transformed feature and the label. """
-     if index >= len(self.indices):
-       raise IndexError("Index out of bounds") # check if index is out of bounds
-     feature, label = self.dataset[self.indices[index]]
-     # apply transformation
-     feature = self.transform(feature)
-     if self.mode == "query": # if mode is query, convert label to one-hot vector
-       label = F.one_hot(torch.tensor(self.classes.index(label)), num_classes=len(self.classes)).float()
-     return feature, label
-   # __getitem__():
-
-   def __len__(self): return len(self.indices)
- # FewShotDataset()
-
- class FewShotEpisoder:
-   """ A class to generate episodes for Few-Shot Learning.
-   Each episode consists of a support set and a query set. """
-   def __init__(self, dataset, classes: list, k_shot: int, n_query: int, transform: typing.Callable):
-     """ Args:
-       dataset (Dataset): The base dataset to generate episodes from.
-       k_shot (int): Number of support samples per class.
-       n_query (int): Number of query samples per class.
-       transform (callable): Transform to be applied to the features. """
-     assert k_shot > 0 and n_query > 0, "k_shot and n_query must be greater than 0." # check if k_shot and n_query are valid
-
-     self.k_shot, self.n_query, self.classes = k_shot, n_query, classes
-     self.dataset, self.transform = dataset, transform
-     self.indices_c = self.get_class_indices()
-   # __init__()
-
-   def get_class_indices(self) -> dict:
-     """ Initialize the class indices for the dataset.
-     Returns: dict mapping each class index to a list of sample indices for that class. """
-     indices_c = {label: [] for label in range(len(self.classes))}
-     for index, (_, label) in enumerate(self.dataset):
-       if label in self.classes: indices_c[self.classes.index(label)].append(index)
-     for label, _indices_c in indices_c.items():
-       indices_c[label] = random.sample(_indices_c, self.k_shot + self.n_query)
-     return indices_c
-   # get_class_indices()
-
-   def get_episode(self) -> tuple: # select classes using the list of chosen indices
-     """ Generate an episode consisting of a support set and a query set.
-     Returns: tuple of a FewShotDataset for the support set and a FewShotDataset for the query set. """
-     # get support and query examples
-     support_examples, query_examples = [], []
-     for class_label in range(len(self.classes)):
-       if len(self.indices_c[class_label]) < self.k_shot + self.n_query: continue # skip class if it doesn't have enough samples
-       selected_indices = random.sample(self.indices_c[class_label], self.k_shot + self.n_query)
-       support_examples.extend(selected_indices[:self.k_shot])
-       query_examples.extend(selected_indices)
-
-     # init support and query datasets
-     support_set = FewShotDataset(self.dataset, support_examples, self.classes, self.transform, "support")
-     query_set = FewShotDataset(self.dataset, query_examples, self.classes, self.transform, "query")
-
-     return support_set, query_set
-   # get_episode()
- # FewShotEpisoder()
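A minimal driving sketch for the (now deleted) classes above, assuming they are still importable; the in-memory toy dataset and identity transform are assumptions, since the real callers used ImageFolder plus a Resize/ToTensor/Normalize compose:

import torch
from src.FewShotEpisoder import FewShotEpisoder  # path as it existed before this commit

# toy (feature, label) dataset: 10 samples for each of 3 integer classes
toy = [(torch.randn(3, 8, 8), c) for c in range(3) for _ in range(10)]

episoder = FewShotEpisoder(toy, classes=[0, 1, 2], k_shot=5, n_query=2, transform=lambda x: x)
support_set, query_set = episoder.get_episode()
print(len(support_set), len(query_set))  # 15 21; as committed, the query set keeps all k_shot + n_query indices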
src/evaluate.py DELETED
@@ -1,48 +0,0 @@
- import random
- import torch
- from torch import nn
- from torch.utils.data import DataLoader
- import torchvision as tv
- from src.model.ProtoNet import ProtoNet
- from src.FewShotEpisoder import FewShotEpisoder
-
- def evaluate(MODEL: str, DATASET: str):
-   device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # select device
-
-   # load checkpoint
-   data = torch.load(MODEL)
-   n_way, k_shot, n_query = data["framework"]
-
-   # load model
-   model = ProtoNet(*data["model_config"].values()).to(device)
-   model.load_state_dict(data["state"])
-   model.eval()
-
-   # create FSL episode generator
-   imageset = tv.datasets.ImageFolder(root=DATASET)
-   unseen_classes = [_ for _ in random.sample(list(imageset.class_to_idx.values()), n_way)]
-   episoder = FewShotEpisoder(imageset, unseen_classes, k_shot, n_query, data["transform"])
-
-   # compute prototypes from support examples
-   support_set, query_set = episoder.get_episode()
-   prototypes = list()
-   embedded_features_list = [[] for _ in range(len(support_set.classes))]
-   for embedded_feature, label in support_set: embedded_features_list[unseen_classes.index(label)].append(embedded_feature)
-   for embedded_features in embedded_features_list:
-     class_prototype = torch.stack(embedded_features).mean(dim=0)
-     prototypes.append(class_prototype.flatten())
-   prototypes = torch.stack(prototypes)
-   model.prototyping(prototypes)
-
-   # eval model
-   total_loss, count, n_problem = 0., 0, len(query_set)
-   criterion = nn.CrossEntropyLoss()
-   for feature, label in DataLoader(query_set, shuffle=True):
-     pred = model.forward(feature)
-     loss = criterion(pred, label)
-     total_loss += loss.item()
-     if torch.argmax(pred) == torch.argmax(label): count += 1
-   print(f"seen classes: {data['seen_classes']}\nunseen classes: {unseen_classes}\naccuracy: {count / n_problem:.4f}({count}/{n_problem})")
- # evaluate()
-
- if __name__ == "__main__": evaluate("./model/model.pth", "../data/omniglot-py/images_background/Futurama")
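The prototype block above is just a per-class mean of support features followed by a flatten; a compact vectorized equivalent on dummy tensors (the shapes are assumptions for illustration):

import torch

n_way, k_shot, dim = 5, 5, 192
support_embeddings = torch.randn(n_way * k_shot, dim)  # assumed grouped class-major
prototypes = support_embeddings.view(n_way, k_shot, dim).mean(dim=1)
print(prototypes.shape)  # torch.Size([5, 192])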
src/train.py DELETED
@@ -1,73 +0,0 @@
- import random
- import torch.cuda
- import torchvision as tv
- from torch import nn
- from tqdm import tqdm
- from torch.utils.data import DataLoader
- from src.FewShotEpisoder import FewShotEpisoder
- from src.model.ProtoNet import ProtoNet
- from config import TRAINING_CONFIG, HYPERPARAMETER_CONFIG
-
- def train(DATASET: str, SAVE_TO: str, N_WAY: int, K_SHOT: int, N_QUERY: int, ITERS=TRAINING_CONFIG["iters"], EPOCHS=TRAINING_CONFIG["epochs"]):
-   device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # init device
-
-   # define transform
-   transform = tv.transforms.Compose([
-     tv.transforms.Resize((224, 224)),
-     tv.transforms.ToTensor(),
-     tv.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-   ]) # transform
-
-   # init episode generator
-   imageset = tv.datasets.ImageFolder(root=DATASET)
-   seen_classes = [_ for _ in random.sample(list(imageset.class_to_idx.values()), N_WAY)]
-   episoder = FewShotEpisoder(imageset, seen_classes, K_SHOT, N_QUERY, transform)
-
-   # init model
-   model_config = {"in_channels": 3, "hidden_channels": 26, "output_channels": 3}
-   model = ProtoNet(*model_config.values()).to(device)
-   optim = torch.optim.Adam(model.parameters(), lr=HYPERPARAMETER_CONFIG["lr"], weight_decay=HYPERPARAMETER_CONFIG["weight_decay"])
-   criterion = nn.CrossEntropyLoss()
-
-   progress_bar, whole_loss = tqdm(range(EPOCHS)), float()
-   support_set, query_set = episoder.get_episode()
-   for _ in progress_bar:
-     # STAGE 1: compute prototypes from support examples
-     prototypes = list()
-     embedded_features_list = [[] for _ in range(len(support_set.classes))]
-     for embedded_feature, label in support_set: embedded_features_list[seen_classes.index(label)].append(embedded_feature)
-     for embedded_features in embedded_features_list:
-       class_prototype = torch.stack(embedded_features).mean(dim=0)
-       prototypes.append(class_prototype.flatten())
-     # for
-     prototypes = torch.stack(prototypes)
-     model.prototyping(prototypes)
-     # STAGE 2: update parameters from the loss associated with the prototypes
-     epochs_loss = 0.0
-     for _ in range(ITERS):
-       iter_loss = 0.0
-       for feature, label in DataLoader(query_set, shuffle=True):
-         loss = criterion(model.forward(feature), label)
-         iter_loss += loss.item()
-         optim.zero_grad()
-         loss.backward()
-         optim.step()
-       epochs_loss += iter_loss / len(query_set)
-     # for # for
-     epochs_loss = epochs_loss / ITERS
-     progress_bar.set_postfix(loss=epochs_loss)
-   # for
-
-   # save the model's parameters and the other data
-   features = {
-     "state": model.state_dict(),
-     "model_config": model_config,
-     "transform": transform,
-     "seen_classes": seen_classes,
-     "framework": (N_WAY, K_SHOT, N_QUERY)
-   } # features
-   torch.save(features, SAVE_TO)
-   print(f"model saved to {SAVE_TO}")
- # train()
-
- if __name__ == "__main__": train("../data/omniglot-py/images_background/Futurama", "./model/model.pth", 5, 5, 2)
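The features dict saved here uses exactly the keys that evaluate.py reads back (state, model_config, transform, seen_classes, framework). A hedged round-trip sketch of that schema; the nn.Linear stand-in and the file name are assumptions:

import torch
import torchvision as tv
from torch import nn

model = nn.Linear(4, 2)  # stand-in for ProtoNet
features = {
  "state": model.state_dict(),
  "model_config": {"in_channels": 3, "hidden_channels": 26, "output_channels": 3},
  "transform": tv.transforms.Compose([tv.transforms.ToTensor()]),
  "seen_classes": [0, 1, 2, 3, 4],
  "framework": (5, 5, 2),
}
torch.save(features, "model.pth")

# weights_only=False is needed on recent PyTorch because the dict pickles a transform
data = torch.load("model.pth", weights_only=False)
n_way, k_shot, n_query = data["framework"]  # unpacks as evaluate.py expects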