# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
from mmcv.runner import get_dist_info
from torch.utils.data import Sampler
from mmdet.core.utils import sync_random_seed
class ClassAwareSampler(Sampler):
r"""Sampler that restricts data loading to the label of the dataset.
A class-aware sampling strategy to effectively tackle the
non-uniform class distribution. The length of the training data is
consistent with source data. Simple improvements based on `Relay
Backpropagation for Effective Learning of Deep Convolutional
Neural Networks <https://arxiv.org/abs/1512.05830>`_
The implementation logic is referred to
https://github.com/Sense-X/TSD/blob/master/mmdet/datasets/samplers/distributed_classaware_sampler.py
    Args:
        dataset: Dataset used for sampling.
        samples_per_gpu (int): When the model is
            :obj:`DistributedDataParallel`, it is the number of training
            samples on each GPU. When the model is :obj:`DataParallel`,
            it is `num_gpus * samples_per_gpu`. Default: 1.
        num_replicas (int, optional): Number of processes participating in
            distributed training.
        rank (int, optional): Rank of the current process within
            ``num_replicas``.
        seed (int, optional): Random seed used to shuffle the sampler.
            This number should be identical across all
            processes in the distributed group. Default: 0.
        num_sample_class (int): The number of samples taken from each
            per-label list. Default: 1.
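
    Example:
        Illustrative usage (assumes ``dataset`` implements ``get_cat2imgs``,
        as MMDetection datasets do; ``num_epochs`` is a placeholder):

        >>> sampler = ClassAwareSampler(dataset, samples_per_gpu=2)
        >>> loader = torch.utils.data.DataLoader(
        ...     dataset, batch_size=2, sampler=sampler)
        >>> for epoch in range(num_epochs):
        ...     sampler.set_epoch(epoch)
        ...     for batch in loader:
        ...         ...  # each batch is drawn class-uniformly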
"""
def __init__(self,
dataset,
samples_per_gpu=1,
num_replicas=None,
rank=None,
seed=0,
num_sample_class=1):
_rank, _num_replicas = get_dist_info()
if num_replicas is None:
num_replicas = _num_replicas
if rank is None:
rank = _rank
self.dataset = dataset
self.num_replicas = num_replicas
self.samples_per_gpu = samples_per_gpu
self.rank = rank
self.epoch = 0
        # Must be the same across all workers. If None, a random seed
        # shared among all workers will be used (requires synchronization
        # across all workers).
self.seed = sync_random_seed(seed)
# The number of samples taken from each per-label list
assert num_sample_class > 0 and isinstance(num_sample_class, int)
self.num_sample_class = num_sample_class
# Get per-label image list from dataset
assert hasattr(dataset, 'get_cat2imgs'), \
'dataset must have `get_cat2imgs` function'
self.cat_dict = dataset.get_cat2imgs()
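        # Round the per-rank sample count up to a whole number of batches so
        # that every rank draws equally many samples, e.g. len(dataset)=10,
        # num_replicas=4, samples_per_gpu=2 -> num_samples=4, total_size=16.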
self.num_samples = int(
math.ceil(
len(self.dataset) * 1.0 / self.num_replicas /
self.samples_per_gpu)) * self.samples_per_gpu
self.total_size = self.num_samples * self.num_replicas
# get number of images containing each category
self.num_cat_imgs = [len(x) for x in self.cat_dict.values()]
# filter labels without images
self.valid_cat_inds = [
i for i, length in enumerate(self.num_cat_imgs) if length != 0
]
self.num_classes = len(self.valid_cat_inds)
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch + self.seed)
# initialize label list
label_iter_list = RandomCycleIter(self.valid_cat_inds, generator=g)
# initialize each per-label image list
data_iter_dict = dict()
for i in self.valid_cat_inds:
data_iter_dict[i] = RandomCycleIter(self.cat_dict[i], generator=g)
        def gen_cat_img_inds(cls_list, data_dict, num_sample_cls):
            """Traverse the categories and extract ``num_sample_cls`` image
            indices for each category in turn."""
            id_indices = []
            for _ in range(len(cls_list)):
                cls_idx = next(cls_list)
                for _ in range(num_sample_cls):
                    # `img_id` avoids shadowing the built-in `id`
                    img_id = next(data_dict[cls_idx])
                    id_indices.append(img_id)
            return id_indices
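        # e.g. with two valid categories and num_sample_cls=2, one call
        # returns four image indices, two from each category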
        # number of rounds needed to generate at least `total_size` indices
num_bins = int(
math.ceil(self.total_size * 1.0 / self.num_classes /
self.num_sample_class))
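        # e.g. total_size=16, num_classes=3, num_sample_class=1 gives
        # num_bins=6, over-generating 18 indices; the surplus is trimmed below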
indices = []
        for _ in range(num_bins):
indices += gen_cat_img_inds(label_iter_list, data_iter_dict,
self.num_sample_class)
        # trim surplus indices, or pad by wrapping around, so that the
        # length is exactly `total_size`
if len(indices) >= self.total_size:
indices = indices[:self.total_size]
else:
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
offset = self.num_samples * self.rank
indices = indices[offset:offset + self.num_samples]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
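        """Record the current epoch; typically called once per epoch by the
        training loop so each epoch uses a different deterministic shuffle."""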
self.epoch = epoch
class RandomCycleIter:
"""Shuffle the list and do it again after the list have traversed.
The implementation logic is referred to
https://github.com/wutong16/DistributionBalancedLoss/blob/master/mllt/datasets/loader/sampler.py
Example:
>>> label_list = [0, 1, 2, 4, 5]
>>> g = torch.Generator()
>>> g.manual_seed(0)
>>> label_iter_list = RandomCycleIter(label_list, generator=g)
>>> index = next(label_iter_list)
    Args:
        data (list or ndarray): The data that needs to be shuffled.
        generator: A ``torch.Generator`` object used to seed the random
            number generation.
""" # noqa: W605
def __init__(self, data, generator=None):
self.data = data
self.length = len(data)
self.index = torch.randperm(self.length, generator=generator).numpy()
self.i = 0
self.generator = generator
def __iter__(self):
return self
def __len__(self):
return len(self.data)
def __next__(self):
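        # reshuffle the permutation once every element has been visited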
if self.i == self.length:
self.index = torch.randperm(
self.length, generator=self.generator).numpy()
self.i = 0
idx = self.data[self.index[self.i]]
self.i += 1
return idx