Spaces:
Runtime error
Runtime error
File size: 4,622 Bytes
51f6859 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
def mask_matrix_nms(masks,
                    labels,
                    scores,
                    filter_thr=-1,
                    nms_pre=-1,
                    max_num=-1,
                    kernel='gaussian',
                    sigma=2.0,
                    mask_area=None):
    """Matrix NMS for multi-class masks.

    Rather than discarding overlapping masks, every mask's score is
    multiplied by a decay coefficient derived from its IoU with the
    higher-scoring masks of the same label, compensated by how strongly
    each of those masks overlaps its own higher-scoring same-label masks.

    Args:
        masks (Tensor): Has shape (num_instances, h, w).
        labels (Tensor): Labels of corresponding masks,
            has shape (num_instances,).
        scores (Tensor): Mask scores of corresponding masks,
            has shape (num_instances,).
        filter_thr (float): Score threshold to filter the masks
            after matrix nms. Default: -1, which means do not
            use filter_thr.
        nms_pre (int): The max number of instances to do the matrix nms.
            Default: -1, which means do not use nms_pre.
        max_num (int, optional): If there are more than max_num masks after
            matrix nms, only top max_num will be kept. Default: -1, which
            means do not use max_num.
        kernel (str): 'linear' or 'gaussian'. Default: 'gaussian'.
        sigma (float): Coefficient of the gaussian kernel, which computes
            ``exp(-sigma * iou**2)``. Default: 2.0.
        mask_area (Tensor, optional): The per-mask sum of ``masks``, has
            shape (num_instances,). Computed from ``masks`` when None.

    Returns:
        tuple(Tensor): Processed mask results.

            - scores (Tensor): Updated scores, has shape (n,).
            - labels (Tensor): Remained labels, has shape (n,).
            - masks (Tensor): Remained masks, has shape (n, h, w).
            - keep_inds (Tensor): The indices of the remaining masks in
              the input masks, has shape (n,).

    Raises:
        NotImplementedError: If ``kernel`` is neither 'linear' nor
            'gaussian'.
    """
    assert len(labels) == len(masks) == len(scores)
    if len(labels) == 0:
        return scores.new_zeros(0), labels.new_zeros(0), masks.new_zeros(
            0, *masks.shape[-2:]), labels.new_zeros(0)
    if mask_area is None:
        mask_area = masks.sum((1, 2)).float()
    else:
        assert len(masks) == len(mask_area)

    # Sort by descending score and keep at most the top ``nms_pre``.
    scores, sort_inds = torch.sort(scores, descending=True)
    if nms_pre > 0 and len(sort_inds) > nms_pre:
        sort_inds = sort_inds[:nms_pre]
        scores = scores[:nms_pre]
    # Positions of the surviving masks in the caller's input tensors.
    keep_inds = sort_inds
    masks = masks[sort_inds]
    mask_area = mask_area[sort_inds]
    labels = labels[sort_inds]

    num_masks = len(labels)
    flatten_masks = masks.reshape(num_masks, -1).float()
    # Pairwise intersection: inter_matrix[i, j] = sum(mask_i * mask_j).
    inter_matrix = torch.mm(flatten_masks, flatten_masks.transpose(1, 0))
    expanded_mask_area = mask_area.expand(num_masks, num_masks)
    union_matrix = (expanded_mask_area + expanded_mask_area.transpose(1, 0) -
                    inter_matrix)
    # Two empty masks have zero union (and zero intersection). Dividing by 1
    # there yields IoU 0 instead of 0 / 0 = NaN, which would otherwise spread
    # through the max/min reductions below and turn every score into NaN.
    union_matrix = union_matrix.masked_fill(union_matrix == 0, 1.0)
    # Upper-triangle IoU matrix: entry (i, j) with i < j is the IoU of
    # mask j with the higher-scoring mask i.
    iou_matrix = (inter_matrix / union_matrix).triu(diagonal=1)
    # Upper-triangle matrix marking pairs that share a label.
    expanded_labels = labels.expand(num_masks, num_masks)
    label_matrix = (expanded_labels == expanded_labels.transpose(
        1, 0)).triu(diagonal=1)

    # IoU decay: only same-label pairs suppress each other.
    decay_iou = iou_matrix * label_matrix
    # IoU compensation: for each mask, its largest IoU with any
    # higher-scoring same-label mask, broadcast so that
    # compensate_iou[i, j] holds the value for mask i.
    compensate_iou, _ = decay_iou.max(0)
    compensate_iou = compensate_iou.expand(num_masks,
                                           num_masks).transpose(1, 0)

    # Decay coefficient of mask j: min over i of f(iou_ij) / f(comp_i).
    if kernel == 'gaussian':
        decay_matrix = torch.exp(-1 * sigma * (decay_iou**2))
        compensate_matrix = torch.exp(-1 * sigma * (compensate_iou**2))
        decay_coefficient, _ = (decay_matrix / compensate_matrix).min(0)
    elif kernel == 'linear':
        decay_matrix = (1 - decay_iou) / (1 - compensate_iou)
        decay_coefficient, _ = decay_matrix.min(0)
    else:
        raise NotImplementedError(
            f'{kernel} kernel is not supported in matrix nms!')
    # Update the scores with the decay coefficients.
    scores = scores * decay_coefficient

    if filter_thr > 0:
        keep = scores >= filter_thr
        if not keep.any():
            return scores.new_zeros(0), labels.new_zeros(0), masks.new_zeros(
                0, *masks.shape[-2:]), labels.new_zeros(0)
        keep_inds = keep_inds[keep]
        masks = masks[keep]
        scores = scores[keep]
        labels = labels[keep]

    # Re-sort by decayed score and keep at most the top ``max_num``.
    scores, sort_inds = torch.sort(scores, descending=True)
    if max_num > 0 and len(sort_inds) > max_num:
        sort_inds = sort_inds[:max_num]
        scores = scores[:max_num]
    keep_inds = keep_inds[sort_inds]
    masks = masks[sort_inds]
    labels = labels[sort_inds]

    return scores, labels, masks, keep_inds
|