Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch | |
| from ..builder import BBOX_SAMPLERS | |
| from .base_sampler import BaseSampler | |
class RandomSampler(BaseSampler):
    """Sampler that picks positive and negative indices at random.

    Args:
        num (int): Number of samples.
        pos_fraction (float): Fraction of positive samples.
        neg_pos_ub (int, optional): Upper bound number of negative and
            positive samples. Defaults to -1.
        add_gt_as_proposals (bool, optional): Whether to add ground truth
            boxes as proposals. Defaults to True.
        **kwargs: Only the optional ``rng`` entry is read. It is passed to
            ``demodata.ensure_rng`` and the result is kept in ``self.rng``.
            Note that no method defined in this class reads ``self.rng``;
            the shuffling in :meth:`random_choice` goes through
            ``torch.randperm``.
    """

    def __init__(self,
                 num,
                 pos_fraction,
                 neg_pos_ub=-1,
                 add_gt_as_proposals=True,
                 **kwargs):
        # Function-scope import, kept as in the original
        # (presumably to avoid a circular import -- TODO confirm).
        from mmdet.core.bbox import demodata

        super().__init__(num, pos_fraction, neg_pos_ub, add_gt_as_proposals)
        self.rng = demodata.ensure_rng(kwargs.get('rng'))

    def random_choice(self, gallery, num):
        """Pick ``num`` random elements out of ``gallery``.

        A Tensor ``gallery`` yields a Tensor result; any other input (e.g.
        an ndarray or a list) is converted to a long Tensor for the
        selection and the result is handed back as an ndarray.

        Args:
            gallery (Tensor | ndarray | list): indices pool.
            num (int): expected sample num; must not exceed
                ``len(gallery)``.

        Returns:
            Tensor or ndarray: sampled indices.
        """
        assert len(gallery) >= num

        from_tensor = isinstance(gallery, torch.Tensor)
        if not from_tensor:
            # Non-tensor pools are placed on the current CUDA device when
            # one is available, otherwise on the CPU.
            device = (
                torch.cuda.current_device()
                if torch.cuda.is_available() else 'cpu')
            gallery = torch.tensor(gallery, dtype=torch.long, device=device)

        # Temporary workaround: the permutation is created with the default
        # ``torch.randperm`` call and only then moved to the gallery's
        # device. This can be reverted once PyTorch fixes the abnormal
        # return of torch.randperm.
        # See: https://github.com/open-mmlab/mmdetection/pull/5014
        shuffled = torch.randperm(gallery.numel())
        chosen = shuffled[:num].to(device=gallery.device)
        picked = gallery[chosen]

        return picked if from_tensor else picked.cpu().numpy()

    def _sample_pos(self, assign_result, num_expected, **kwargs):
        """Randomly sample some positive samples.

        Candidates are the indices where ``assign_result.gt_inds > 0``.
        When there are no more than ``num_expected`` of them they are all
        returned; otherwise ``num_expected`` are drawn with
        :meth:`random_choice`.
        """
        candidates = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
        if candidates.numel():
            # Drop the trailing dimension added by ``torch.nonzero``.
            candidates = candidates.squeeze(1)

        if candidates.numel() <= num_expected:
            return candidates
        return self.random_choice(candidates, num_expected)

    def _sample_neg(self, assign_result, num_expected, **kwargs):
        """Randomly sample some negative samples.

        Candidates are the indices where ``assign_result.gt_inds == 0``.
        When there are no more than ``num_expected`` of them they are all
        returned; otherwise ``num_expected`` are drawn with
        :meth:`random_choice`.
        """
        candidates = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
        if candidates.numel():
            # Drop the trailing dimension added by ``torch.nonzero``.
            candidates = candidates.squeeze(1)

        if len(candidates) <= num_expected:
            return candidates
        return self.random_choice(candidates, num_expected)