add: FewShotEpisoder and get_prototypes
Browse files- FewShotEpisoder.py +84 -0
- config.json +1 -0
FewShotEpisoder.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import typing
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
class FewShotDataset(Dataset):
|
8 |
+
""" A custom Dataset class for Few-Shot Learning tasks.
|
9 |
+
This dataset can operate in two modes: "support" (for prototype calculation) and "query" (for evaluation). """
|
10 |
+
def __init__(self, dataset, indices: list, classes: list, transform:typing.Callable, mode="support"):
|
11 |
+
""" Args:
|
12 |
+
dataset (list): List of (feature, label) pairs.
|
13 |
+
indices (list): List of indices to be used for the dataset.
|
14 |
+
transform (callable): Transform to be applied to the features.
|
15 |
+
mode (str): Mode of operation, either "support" or "query". Default is "support". """
|
16 |
+
assert mode in ["support", "query"], "Invalid mode. Must be either 'support' or 'query'." # check if mode is valid
|
17 |
+
assert dataset and indices and classes is not None, "Dataset or indices cannot be None." # check if dataset is not None
|
18 |
+
|
19 |
+
self.dataset, self.indices, self.classes = dataset, indices, classes
|
20 |
+
self.mode, self.transform = mode, transform
|
21 |
+
# __init__():
|
22 |
+
|
23 |
+
def __getitem__(self, index: int):
|
24 |
+
""" Returns a sample from the dataset at the given index.
|
25 |
+
Args: index of the sample to be retrieved.
|
26 |
+
Returns: tuple of the transformed feature and the label. """
|
27 |
+
if index >= len(self.indices):
|
28 |
+
raise IndexError("Index out of bounds") # check if index is out of bounds
|
29 |
+
feature, label = self.dataset[self.indices[index]]
|
30 |
+
# apply transformation
|
31 |
+
feature = self.transform(feature)
|
32 |
+
if self.mode == "query": # if mode is query, convert label to one-hot vector
|
33 |
+
label = F.one_hot(torch.tensor(self.classes.index(label)), num_classes=len(self.classes)).float()
|
34 |
+
return feature, label
|
35 |
+
# __getitem__():
|
36 |
+
|
37 |
+
def __len__(self): return len(self.indices)
|
38 |
+
# FSLDataset()
|
39 |
+
|
40 |
+
class FewShotEpisoder:
|
41 |
+
""" A class to generate episodes for Few-Shot Learning.
|
42 |
+
Each episode consists of a support set and a query set. """
|
43 |
+
def __init__(self, dataset, classes: list, k_shot: int, n_query: int, transform: typing.Callable):
|
44 |
+
""" Args:
|
45 |
+
dataset (Dataset): The base dataset to generate episodes from.
|
46 |
+
k_shot (int): Number of support samples per class.
|
47 |
+
n_query (int): Number of query samples per class.
|
48 |
+
transform (callable): Transform to be applied to the features. """
|
49 |
+
assert k_shot > 0 and n_query > 0, "k_shot and n_query must be greater than 0." # check if k_shot and n_query are valid
|
50 |
+
|
51 |
+
self.k_shot, self.n_query, self.classes = k_shot, n_query, classes
|
52 |
+
self.dataset, self.transform = dataset, transform
|
53 |
+
self.indices_c = self.get_class_indices()
|
54 |
+
# __init__()
|
55 |
+
|
56 |
+
def get_class_indices(self) -> dict:
|
57 |
+
""" Initialize the class indices for the dataset.
|
58 |
+
Returns: tuple of Number of classes and a list of indices grouped by class. """
|
59 |
+
indices_c = {label: [] for label in range(self.classes.__len__())}
|
60 |
+
for index, (_, label) in enumerate(self.dataset):
|
61 |
+
if label in self.classes: indices_c[self.classes.index(label)].append(index)
|
62 |
+
for label, _indices_c in indices_c.items():
|
63 |
+
indices_c[label] = random.sample(_indices_c, self.k_shot + self.n_query)
|
64 |
+
return indices_c
|
65 |
+
# get_indices():
|
66 |
+
|
67 |
+
def get_episode(self) -> tuple: # select classes using list of chosen indexes
|
68 |
+
""" Generate an episode consisting of a support set and a query set.
|
69 |
+
Returns: tuple of A FewShotDataset for the support set and a FewShotDataset for the query set. """
|
70 |
+
# get support and query examples
|
71 |
+
support_examples, query_examples = [], []
|
72 |
+
for class_label in range(self.classes.__len__()):
|
73 |
+
if len(self.indices_c[class_label]) < self.k_shot + self.n_query: continue # skip class if it doesn't have enough samples
|
74 |
+
selected_indices = random.sample(self.indices_c[class_label], self.k_shot + self.n_query)
|
75 |
+
support_examples.extend(selected_indices[:self.k_shot])
|
76 |
+
query_examples.extend(selected_indices)
|
77 |
+
|
78 |
+
# init support and query datasets
|
79 |
+
support_set = FewShotDataset(self.dataset, support_examples, self.classes, self.transform, "support")
|
80 |
+
query_set = FewShotDataset(self.dataset, query_examples, self.classes, self.transform, "query")
|
81 |
+
|
82 |
+
return support_set, query_set
|
83 |
+
# get_episode()
|
84 |
+
# Episoder()
|
config.json
CHANGED
@@ -4,5 +4,6 @@
|
|
4 |
"MODEL_CONFIG": [3, 26, 3],
|
5 |
"HYPER_PARAMETERS": {"lr": 0.0001, "weight_decay": 0.0001},
|
6 |
"TRANSFORM": "transform",
|
|
|
7 |
"METRIC": "euclidean"
|
8 |
}
|
|
|
4 |
"MODEL_CONFIG": [3, 26, 3],
|
5 |
"HYPER_PARAMETERS": {"lr": 0.0001, "weight_decay": 0.0001},
|
6 |
"TRANSFORM": "transform",
|
7 |
+
"GET_PROTOTYPES": "get_prototypes(support_set)",
|
8 |
"METRIC": "euclidean"
|
9 |
}
|