Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| import torch | |
| from detectron2.layers import nonzero_tuple | |
# Names exported by `from <this module> import *`.
__all__ = [
    "subsample_labels",
]
def subsample_labels(
    labels: torch.Tensor, num_samples: int, positive_fraction: float, bg_label: int
):
    """
    Return `num_samples` (or fewer, if not enough found)
    random samples from `labels` which is a mixture of positives & negatives.

    It will try to return as many positives as possible without
    exceeding `positive_fraction * num_samples`, and then try to
    fill the remaining slots with negatives.

    Args:
        labels (Tensor): (N, ) label vector with values:
            * -1: ignore
            * bg_label: background ("negative") class
            * otherwise: one or more foreground ("positive") classes
        num_samples (int): The total number of non-ignored labels (positives
            plus negatives) to sample. This function only returns indices; it
            does not modify `labels`, so marking unsampled entries (e.g. as -1)
            is left to the caller.
        positive_fraction (float): The number of sampled positives (labels that
            are neither -1 nor `bg_label`) is
            `min(num_positives, int(positive_fraction * num_samples))`. The number
            of negatives sampled is `min(num_negatives, num_samples - num_positives_sampled)`.
            In other words, if there are not enough positives, the sample is filled
            with negatives. If there are also not enough negatives, then as many
            elements are sampled as is possible. Both counts are clamped to
            `[0, num_samples]`, so an out-of-range value (< 0 or > 1) never yields
            more than `num_samples` indices in total.
        bg_label (int): label index of background ("negative") class.

    Returns:
        pos_idx, neg_idx (Tensor):
            1D vectors of indices into `labels`. The total length of both is
            `num_samples` or fewer.
    """
    positive = nonzero_tuple((labels != -1) & (labels != bg_label))[0]
    negative = nonzero_tuple(labels == bg_label)[0]

    # Positive quota: a fraction of the budget, limited by the positives that
    # actually exist. Clamp to [0, num_samples]: a negative count would become a
    # negative slice bound below, silently selecting "all but the last k"
    # elements instead of k, and a count above num_samples would break the
    # size contract.
    num_pos = int(num_samples * positive_fraction)
    num_pos = min(num_pos, num_samples)
    # protect against not enough positive examples
    num_pos = min(positive.numel(), num_pos)
    num_pos = max(num_pos, 0)

    # Fill the remaining budget with negatives.
    num_neg = num_samples - num_pos
    # protect against not enough negative examples
    num_neg = min(negative.numel(), num_neg)
    num_neg = max(num_neg, 0)

    # Take a prefix of a random permutation, which selects a random subset
    # without replacement.
    perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
    perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]

    pos_idx = positive[perm1]
    neg_idx = negative[perm2]
    return pos_idx, neg_idx