import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data.sampler import Sampler
from tqdm import *


class BalancedSampler(Sampler):
    """Sampler that emits dataset indices in class-balanced groups.

    One pass is built from ``len(data_source) // batch_size`` rounds. In each
    round ``batch_size // images_per_class`` distinct classes are drawn
    without replacement, and for every drawn class ``images_per_class``
    sample indices of that class are drawn with replacement (so the same
    index may appear more than once within a group).

    Args:
        data_source: Dataset-like object exposing ``all_labels`` (one label
            per sample) and supporting ``len()``.
        batch_size: Target number of indices per round.
        images_per_class: Number of indices drawn for each sampled class.

    Raises:
        ValueError: Raised by NumPy while iterating when
            ``batch_size // images_per_class`` exceeds the number of
            distinct classes (classes are drawn without replacement).
    """

    def __init__(self, data_source, batch_size: int, images_per_class: int = 3):
        self.data_source = data_source
        self.ys = np.array(data_source.all_labels)
        self.num_groups = batch_size // images_per_class
        self.batch_size = batch_size
        self.num_instances = images_per_class
        self.num_samples = len(self.ys)
        # Distinct label values actually present. Classes are sampled from
        # these values rather than from range(num_classes), so label sets
        # that are not exactly 0..C-1 are handled correctly.
        self.classes = np.unique(self.ys)
        self.num_classes = len(self.classes)

    def __len__(self) -> int:
        """Return the number of indices produced by one pass of ``__iter__``."""
        num_batches = len(self.data_source) // self.batch_size
        # Each round contributes num_groups * num_instances indices, which is
        # smaller than batch_size when batch_size is not a multiple of
        # images_per_class.
        return num_batches * self.num_groups * self.num_instances

    def __iter__(self):
        """Build one full pass of balanced indices and return an iterator over it."""
        num_batches = len(self.data_source) // self.batch_size
        ret = []
        for _ in range(num_batches):
            # Distinct classes for this round, drawn from the real label values.
            sampled_classes = np.random.choice(
                self.classes, self.num_groups, replace=False
            )
            for label in sampled_classes:
                label_idxs = np.nonzero(self.ys == label)[0]
                # With replacement: a class holding fewer than num_instances
                # samples can still fill its group.
                class_sel = np.random.choice(
                    label_idxs, size=self.num_instances, replace=True
                )
                ret.extend(np.random.permutation(class_sel))
        return iter(ret)