# rayst3r/models/heads/postprocess.py
# Last commit: 70d1188 ("init") by bartduis.
import torch
def postprocess(out, depth_mode, conf_mode, classifier_mode=None):
    """
    Turn a raw prediction-head output into a dict of activated maps.

    ``out`` is moved channels-last with ``permute(0, 2, 3, 1)``. Presumably this
    maps ``(B, C, H, W)`` -> ``(B, H, W, C)``; confirm against the head that
    produces it. The entries produced depend on ``classifier_mode`` and on the
    channel count:

    * ``classifier_mode`` given: ``'classifier'`` from channel 0 and, if
      ``conf_mode`` is not None, ``'conf_classifier'`` from channel 1.
    * otherwise, exactly 4 channels: ``'pointmaps'`` from channels 0-2
      (via ``reg_dense_pts3d`` with ``depth_mode``).
    * otherwise, any other channel count: ``'depths'`` from channel 0
      (via ``reg_dense_depth`` with ``depth_mode``).

    In both non-classifier cases the confidence (only when ``conf_mode`` is not
    None) is read from the LAST channel and stored under ``'conf_pointmaps'``.
    This key is used even when the main output is ``'depths'``.
    """
    fmap = out.permute(0, 2, 3, 1)  # channel axis moved to the end

    if classifier_mode is not None:
        res = dict(classifier=reg_dense_classifier(fmap[..., 0], mode=classifier_mode))
        conf_key, conf_channel = 'conf_classifier', 1
    else:
        if fmap.shape[-1] == 4:
            # first three channels are the xyz activations, the fourth is confidence
            res = dict(pointmaps=reg_dense_pts3d(fmap[..., :3], mode=depth_mode))
        else:
            res = dict(depths=reg_dense_depth(fmap[..., 0], mode=depth_mode))
        conf_key, conf_channel = 'conf_pointmaps', -1

    # the confidence channel is only touched when a confidence mode is configured
    if conf_mode is not None:
        res[conf_key] = reg_dense_conf(fmap[..., conf_channel], mode=conf_mode)
    return res
def reg_dense_classifier(x, mode):
    """
    Return the classifier channel unchanged.

    ``mode`` must unpack into exactly three values ``(name, vmin, vmax)``.
    The unpack is kept so that a malformed mode still fails here, but none of
    the three values influence the result. No activation is applied, so ``x``
    itself is returned. Presumably downstream code treats it as logits; confirm
    with the loss.
    """
    _name, _vmin, _vmax = mode  # structural check only; values intentionally unused
    return x
def reg_dense_depth(x, mode):
    """
    Extract depth from the prediction-head output.

    Args:
        x: tensor of raw depth activations; all operations are elementwise,
            so any shape is accepted.
        mode: ``(name, vmin, vmax)`` triple. ``name`` selects the activation:
            ``'linear'`` (identity), ``'square'`` (``x ** 2``) or ``'exp'``
            (``exp(x)``). The result is clipped to ``[vmin, vmax]``.

    Returns:
        Tensor with the same shape as ``x``. For ``'linear'`` with infinite
        bounds, ``x`` itself is returned.

    Raises:
        ValueError: if ``name`` is not one of the supported activations.
    """
    # Rebinding ``mode`` to the name keeps the error message as ``bad mode='...'``.
    mode, vmin, vmax = mode
    if mode == 'linear':
        no_bounds = (vmin == -float('inf')) and (vmax == float('inf'))
        if no_bounds:
            return x  # [-inf, +inf]
        # Bounds are honored the same way as in the other modes (previously an
        # ``assert`` rejected them, which is stripped under ``python -O``).
        return x.clip(min=vmin, max=vmax)
    if mode == 'square':
        return x.square().clip(min=vmin, max=vmax)
    if mode == 'exp':
        return torch.exp(x).clip(min=vmin, max=vmax)
    raise ValueError(f'bad {mode=}')
def reg_dense_pts3d(xyz, mode):
    """
    Extract 3D points from the prediction-head output.

    Args:
        xyz: tensor whose last axis holds the raw (x, y, z) activations.
        mode: ``(name, vmin, vmax)`` triple. ``name`` selects the activation:
            * ``'linear'``: identity; clipped elementwise to ``[vmin, vmax]``
              when the bounds are finite.
            * ``'square'`` / ``'exp'``: keep each vector's direction and remap
              its Euclidean norm ``d`` to ``d ** 2`` or ``expm1(d)``.
              These modes require infinite bounds.

    Returns:
        Tensor with the same shape as ``xyz``. For ``'linear'`` with infinite
        bounds, ``xyz`` itself is returned.

    Raises:
        ValueError: if ``name`` is unknown, or if finite bounds are given for
            ``'square'``/``'exp'`` (no clipping is defined for them). This
            replaces an ``assert``, which ``python -O`` would strip and silently
            ignore the bounds.
    """
    # Rebinding ``mode`` to the name keeps the error message as ``bad mode='...'``.
    mode, vmin, vmax = mode
    no_bounds = (vmin == -float('inf')) and (vmax == float('inf'))

    if mode == 'linear':
        if no_bounds:
            return xyz  # [-inf, +inf]
        return xyz.clip(min=vmin, max=vmax)

    # Validate before doing any norm computation.
    if mode not in ('square', 'exp'):
        raise ValueError(f'bad {mode=}')
    if not no_bounds:
        raise ValueError(f'bounds ({vmin}, {vmax}) are not supported for {mode=}')

    # distance to origin; the 1e-8 floor avoids dividing by zero for null vectors
    d = xyz.norm(dim=-1, keepdim=True)
    xyz = xyz / d.clip(min=1e-8)
    if mode == 'square':
        return xyz * d.square()
    return xyz * torch.expm1(d)
def reg_dense_conf(x, mode):
    """
    Map raw head activations to confidence values.

    ``mode`` is a ``(name, vmin, vmax)`` triple:

    * ``'exp'``: ``vmin + exp(x)``, with ``exp(x)`` capped at ``vmax - vmin``,
      so the result lies in ``(vmin, vmax]``.
    * ``'sigmoid'``: ``vmin + (vmax - vmin) * sigmoid(x)``, a smooth map
      into ``(vmin, vmax)``.

    Raises:
        ValueError: for any other name.
    """
    # ``mode`` is rebound to the name on purpose: the error message prints it.
    mode, vmin, vmax = mode
    span = vmax - vmin
    if mode == 'exp':
        return vmin + torch.exp(x).clip(max=span)
    if mode == 'sigmoid':
        return span * torch.sigmoid(x) + vmin
    raise ValueError(f'bad {mode=}')