lif31up commited on
Commit
918349d
·
1 Parent(s): 14fa675

add: FewShotEpisoder and get_prototypes

Browse files
Files changed (2) hide show
  1. FewShotEpisoder.py +84 -0
  2. config.json +1 -0
FewShotEpisoder.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import typing
3
+ import torch
4
+ from torch.utils.data import Dataset
5
+ import torch.nn.functional as F
6
+
7
+ class FewShotDataset(Dataset):
8
+ """ A custom Dataset class for Few-Shot Learning tasks.
9
+ This dataset can operate in two modes: "support" (for prototype calculation) and "query" (for evaluation). """
10
+ def __init__(self, dataset, indices: list, classes: list, transform:typing.Callable, mode="support"):
11
+ """ Args:
12
+ dataset (list): List of (feature, label) pairs.
13
+ indices (list): List of indices to be used for the dataset.
14
+ transform (callable): Transform to be applied to the features.
15
+ mode (str): Mode of operation, either "support" or "query". Default is "support". """
16
+ assert mode in ["support", "query"], "Invalid mode. Must be either 'support' or 'query'." # check if mode is valid
17
+ assert dataset and indices and classes is not None, "Dataset or indices cannot be None." # check if dataset is not None
18
+
19
+ self.dataset, self.indices, self.classes = dataset, indices, classes
20
+ self.mode, self.transform = mode, transform
21
+ # __init__():
22
+
23
+ def __getitem__(self, index: int):
24
+ """ Returns a sample from the dataset at the given index.
25
+ Args: index of the sample to be retrieved.
26
+ Returns: tuple of the transformed feature and the label. """
27
+ if index >= len(self.indices):
28
+ raise IndexError("Index out of bounds") # check if index is out of bounds
29
+ feature, label = self.dataset[self.indices[index]]
30
+ # apply transformation
31
+ feature = self.transform(feature)
32
+ if self.mode == "query": # if mode is query, convert label to one-hot vector
33
+ label = F.one_hot(torch.tensor(self.classes.index(label)), num_classes=len(self.classes)).float()
34
+ return feature, label
35
+ # __getitem__():
36
+
37
+ def __len__(self): return len(self.indices)
38
+ # FSLDataset()
39
+
40
+ class FewShotEpisoder:
41
+ """ A class to generate episodes for Few-Shot Learning.
42
+ Each episode consists of a support set and a query set. """
43
+ def __init__(self, dataset, classes: list, k_shot: int, n_query: int, transform: typing.Callable):
44
+ """ Args:
45
+ dataset (Dataset): The base dataset to generate episodes from.
46
+ k_shot (int): Number of support samples per class.
47
+ n_query (int): Number of query samples per class.
48
+ transform (callable): Transform to be applied to the features. """
49
+ assert k_shot > 0 and n_query > 0, "k_shot and n_query must be greater than 0." # check if k_shot and n_query are valid
50
+
51
+ self.k_shot, self.n_query, self.classes = k_shot, n_query, classes
52
+ self.dataset, self.transform = dataset, transform
53
+ self.indices_c = self.get_class_indices()
54
+ # __init__()
55
+
56
+ def get_class_indices(self) -> dict:
57
+ """ Initialize the class indices for the dataset.
58
+ Returns: tuple of Number of classes and a list of indices grouped by class. """
59
+ indices_c = {label: [] for label in range(self.classes.__len__())}
60
+ for index, (_, label) in enumerate(self.dataset):
61
+ if label in self.classes: indices_c[self.classes.index(label)].append(index)
62
+ for label, _indices_c in indices_c.items():
63
+ indices_c[label] = random.sample(_indices_c, self.k_shot + self.n_query)
64
+ return indices_c
65
+ # get_indices():
66
+
67
+ def get_episode(self) -> tuple: # select classes using list of chosen indexes
68
+ """ Generate an episode consisting of a support set and a query set.
69
+ Returns: tuple of A FewShotDataset for the support set and a FewShotDataset for the query set. """
70
+ # get support and query examples
71
+ support_examples, query_examples = [], []
72
+ for class_label in range(self.classes.__len__()):
73
+ if len(self.indices_c[class_label]) < self.k_shot + self.n_query: continue # skip class if it doesn't have enough samples
74
+ selected_indices = random.sample(self.indices_c[class_label], self.k_shot + self.n_query)
75
+ support_examples.extend(selected_indices[:self.k_shot])
76
+ query_examples.extend(selected_indices)
77
+
78
+ # init support and query datasets
79
+ support_set = FewShotDataset(self.dataset, support_examples, self.classes, self.transform, "support")
80
+ query_set = FewShotDataset(self.dataset, query_examples, self.classes, self.transform, "query")
81
+
82
+ return support_set, query_set
83
+ # get_episode()
84
+ # Episoder()
config.json CHANGED
@@ -4,5 +4,6 @@
4
  "MODEL_CONFIG": [3, 26, 3],
5
  "HYPER_PARAMETERS": {"lr": 0.0001, "weight_decay": 0.0001},
6
  "TRANSFORM": "transform",
 
7
  "METRIC": "euclidean"
8
  }
 
4
  "MODEL_CONFIG": [3, 26, 3],
5
  "HYPER_PARAMETERS": {"lr": 0.0001, "weight_decay": 0.0001},
6
  "TRANSFORM": "transform",
7
+ "GET_PROTOTYPES": "get_prototypes(support_set)",
8
  "METRIC": "euclidean"
9
  }