|
import torch |
|
|
|
def postprocess(out, depth_mode, conf_mode, classifier_mode=None):
    """
    Convert raw prediction head output into a dict of named predictions.

    `out` is permuted with (0, 2, 3, 1), i.e. its dim 1 is moved to the last
    axis (presumably (B, C, H, W) -> (B, H, W, C); confirm against the caller).

    Which entries are produced:
      * classifier_mode given      -> 'classifier' from channel 0, and
                                      'conf_classifier' from channel 1.
      * otherwise, 4 channels      -> 'pointmaps' from all but the last channel,
                                      and 'conf_pointmaps' from the last channel.
      * otherwise (any other size) -> 'depths' from channel 0, and
                                      'conf_pointmaps' from the last channel.

    The confidence entry is only added when `conf_mode` is not None.
    """
    fmap = out.permute(0, 2, 3, 1)

    # Pick the main prediction, and remember where its confidence lives.
    if classifier_mode is not None:
        res = {'classifier': reg_dense_classifier(fmap[..., 0], mode=classifier_mode)}
        conf_key, conf_channel = 'conf_classifier', 1
    elif fmap.shape[-1] == 4:
        res = {'pointmaps': reg_dense_pts3d(fmap[..., :-1], mode=depth_mode)}
        conf_key, conf_channel = 'conf_pointmaps', -1
    else:
        res = {'depths': reg_dense_depth(fmap[..., 0], mode=depth_mode)}
        # The depth branch stores its confidence under the pointmap key as well.
        conf_key, conf_channel = 'conf_pointmaps', -1

    if conf_mode is not None:
        res[conf_key] = reg_dense_conf(fmap[..., conf_channel], mode=conf_mode)

    return res
|
|
|
def reg_dense_classifier(x, mode):
    """
    Extract the classifier prediction from the prediction head output.

    `mode` must unpack into exactly three items (name, vmin, vmax). None of
    them is applied here: `x` is returned unchanged.
    """
    # Unpacking only enforces that `mode` is a 3-item iterable.
    _kind, _lower, _upper = mode
    return x
|
|
|
def reg_dense_depth(x, mode):
    """
    Extract depth from prediction head output.

    Args:
        x: raw head output; must support `.square()` and `torch.exp`.
        mode: `(name, vmin, vmax)` triple. `name` selects the activation:
            'linear' (identity), 'square' (x ** 2) or 'exp' (e ** x).
            `vmin` / `vmax` must be -inf / +inf; bounded outputs are not
            supported.

    Returns:
        The activated depth values.

    Raises:
        AssertionError: if `vmin` / `vmax` are not -inf / +inf.
        ValueError: if `name` is not one of the supported activations.
    """
    mode, vmin, vmax = mode

    no_bounds = (vmin == -float('inf')) and (vmax == float('inf'))
    if not no_bounds:
        # Raised explicitly (instead of a bare `assert`) so the check is not
        # stripped under `python -O`, where bounds would otherwise be applied
        # for 'square'/'exp' but silently ignored for 'linear'.
        raise AssertionError(
            f'bounded depth is not supported, got vmin={vmin!r}, vmax={vmax!r}')

    # Bounds are guaranteed infinite from here on, so no clipping is needed.
    if mode == 'linear':
        return x
    if mode == 'square':
        return x.square()
    if mode == 'exp':
        return torch.exp(x)
    raise ValueError(f'bad {mode=}')
|
|
|
def reg_dense_pts3d(xyz, mode):
    """
    Extract 3D points from prediction head output.

    Args:
        xyz: raw head output whose last axis holds the point coordinates
            (the norm is taken over `dim=-1`).
        mode: `(name, vmin, vmax)` triple. `name` selects the activation:
            'linear' returns `xyz` unchanged; 'square' and 'exp' keep each
            point's direction and remap its distance to the origin `d` to
            `d ** 2` or `expm1(d)` respectively.
            `vmin` / `vmax` must be -inf / +inf; bounded outputs are not
            supported.

    Returns:
        The activated points, same shape as `xyz`.

    Raises:
        AssertionError: if `vmin` / `vmax` are not -inf / +inf.
        ValueError: if `name` is not one of the supported activations.
    """
    mode, vmin, vmax = mode

    no_bounds = (vmin == -float('inf')) and (vmax == float('inf'))
    if not no_bounds:
        # Raised explicitly (instead of a bare `assert`) so the check is not
        # stripped under `python -O`, where bounds would otherwise be applied
        # for 'linear' but silently ignored for 'square'/'exp'.
        raise AssertionError(
            f'bounded pts3d is not supported, got vmin={vmin!r}, vmax={vmax!r}')

    if mode == 'linear':
        return xyz

    # Distance to the origin; clipped so all-zero points do not divide by zero.
    d = xyz.norm(dim=-1, keepdim=True)
    direction = xyz / d.clip(min=1e-8)

    if mode == 'square':
        return direction * d.square()
    if mode == 'exp':
        return direction * torch.expm1(d)
    raise ValueError(f'bad {mode=}')
|
|
|
def reg_dense_conf(x, mode):
    """
    Extract confidence from prediction head output.

    `mode` is a `(name, vmin, vmax)` triple:
      * 'exp'     -> vmin + min(exp(x), vmax - vmin), i.e. values in (vmin, vmax]
      * 'sigmoid' -> sigmoid(x) rescaled from [0, 1] to [vmin, vmax]

    Any other name raises ValueError.
    """
    mode, lower, upper = mode

    if mode == 'sigmoid':
        squashed = torch.sigmoid(x)
        return (upper - lower) * squashed + lower
    elif mode == 'exp':
        # Cap the exponential so that adding `lower` never exceeds `upper`.
        capped = x.exp().clip(max=upper - lower)
        return lower + capped
    else:
        raise ValueError(f'bad {mode=}')
|
|
|
|