|
import torch.nn.functional as F
|
|
import numpy as np
|
|
import torch
|
|
|
|
class DecDecoder(object):
|
|
def __init__(self, K, conf_thresh):
|
|
self.K = 17
|
|
self.conf_thresh = conf_thresh
|
|
|
|
def _topk(self, scores):
|
|
batch, cat, height, width = scores.size()
|
|
|
|
topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), self.K)
|
|
|
|
topk_inds = topk_inds % (height * width)
|
|
topk_ys = (topk_inds / width).int().float()
|
|
topk_xs = (topk_inds % width).int().float()
|
|
|
|
topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), self.K)
|
|
topk_inds = self._gather_feat( topk_inds.view(batch, -1, 1), topk_ind).view(batch, self.K)
|
|
topk_ys = self._gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, self.K)
|
|
topk_xs = self._gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, self.K)
|
|
|
|
return topk_score, topk_inds, topk_ys, topk_xs
|
|
|
|
|
|
def _nms(self, heat, kernel=3):
|
|
hmax = F.max_pool2d(heat, (kernel, kernel), stride=1, padding=(kernel - 1) // 2)
|
|
keep = (hmax == heat).float()
|
|
return heat * keep
|
|
|
|
def _gather_feat(self, feat, ind, mask=None):
|
|
dim = feat.size(2)
|
|
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
|
|
feat = feat.gather(1, ind)
|
|
if mask is not None:
|
|
mask = mask.unsqueeze(2).expand_as(feat)
|
|
feat = feat[mask]
|
|
feat = feat.view(-1, dim)
|
|
return feat
|
|
|
|
def _tranpose_and_gather_feat(self, feat, ind):
|
|
feat = feat.permute(0, 2, 3, 1).contiguous()
|
|
feat = feat.view(feat.size(0), -1, feat.size(3))
|
|
feat = self._gather_feat(feat, ind)
|
|
return feat
|
|
|
|
def ctdet_decode(self, heat, wh, reg):
|
|
|
|
|
|
batch, c, height, width = heat.size()
|
|
heat = self._nms(heat)
|
|
scores, inds, ys, xs = self._topk(heat)
|
|
scores = scores.view(batch, self.K, 1)
|
|
reg = self._tranpose_and_gather_feat(reg, inds)
|
|
reg = reg.view(batch, self.K, 2)
|
|
xs = xs.view(batch, self.K, 1) + reg[:, :, 0:1]
|
|
ys = ys.view(batch, self.K, 1) + reg[:, :, 1:2]
|
|
wh = self._tranpose_and_gather_feat(wh, inds)
|
|
wh = wh.view(batch, self.K, 2*4)
|
|
|
|
tl_x = xs - wh[:,:,0:1]
|
|
tl_y = ys - wh[:,:,1:2]
|
|
tr_x = xs - wh[:,:,2:3]
|
|
tr_y = ys - wh[:,:,3:4]
|
|
bl_x = xs - wh[:,:,4:5]
|
|
bl_y = ys - wh[:,:,5:6]
|
|
br_x = xs - wh[:,:,6:7]
|
|
br_y = ys - wh[:,:,7:8]
|
|
|
|
pts = torch.cat([xs, ys,
|
|
tl_x,tl_y,
|
|
tr_x,tr_y,
|
|
bl_x,bl_y,
|
|
br_x,br_y,
|
|
scores], dim=2).squeeze(0)
|
|
return pts.data.cpu().numpy() |