# Ridgeformer / combined_sampler.py
# Uploaded by spandey8 ("Upload 11 files", commit 007d3b9, verified).
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data.sampler import Sampler
from tqdm import *
class BalancedSampler(Sampler):
    """Sampler that yields indices in class-balanced groups.

    Every batch is made of ``batch_size // images_per_class`` distinct
    classes, each contributing ``images_per_class`` indices drawn with
    replacement from that class's samples. Indices are emitted as one
    flat stream; consecutive runs of ``images_per_class`` indices share a
    label.

    Args:
        data_source: Sized dataset exposing ``all_labels``, a sequence
            holding one label per sample.
        batch_size: Nominal number of indices per batch.
        images_per_class: Number of indices drawn for each sampled class.

    Raises:
        ValueError: During iteration, if a batch needs more distinct
            classes than the dataset contains.
    """

    def __init__(self, data_source, batch_size, images_per_class=3):
        self.data_source = data_source
        self.ys = np.array(data_source.all_labels)
        self.num_groups = batch_size // images_per_class
        self.batch_size = batch_size
        self.num_instances = images_per_class
        self.num_samples = len(self.ys)

        # Group sample indices by label once, so iteration does not rescan
        # the full label array for every sampled class. A stable sort keeps
        # the indices of each class in ascending order.
        order = np.argsort(self.ys, kind="stable")
        self.classes, starts = np.unique(self.ys[order], return_index=True)
        self.num_classes = len(self.classes)
        self._class_indices = (
            np.split(order, starts[1:]) if self.num_classes else []
        )

    def __len__(self):
        """Return the number of indices one pass of ``__iter__`` yields."""
        num_batches = len(self.data_source) // self.batch_size
        return num_batches * self.num_groups * self.num_instances

    def __iter__(self):
        """Return an iterator over the class-balanced index stream."""
        num_batches = len(self.data_source) // self.batch_size
        ret = []
        for _ in range(num_batches):
            # Draw *positions* into the list of distinct labels rather than
            # raw label values, so labels need not be the integers 0..C-1.
            sampled = np.random.choice(
                self.num_classes, self.num_groups, replace=False
            )
            for class_pos in sampled:
                members = self._class_indices[class_pos]
                class_sel = np.random.choice(
                    members, size=self.num_instances, replace=True
                )
                # tolist() yields plain Python ints instead of NumPy scalars.
                ret.extend(np.random.permutation(class_sel).tolist())
        return iter(ret)