File size: 2,360 Bytes
70d1188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch

def postprocess(out, depth_mode, conf_mode,classifier_mode=None):
    """
    convert raw prediction-head output into a dict of regressed maps

    out: 4-D tensor whose dim 1 is moved to the end (presumably (B, C, H, W) -- TODO confirm)
    depth_mode / conf_mode / classifier_mode: (kind, vmin, vmax) specs forwarded to the
        reg_dense_* helpers; conf_mode=None omits the confidence entry
    returns a dict:
        classifier_mode given -> 'classifier' from channel 0, 'conf_classifier' from channel 1
        exactly 4 channels    -> 'pointmaps' from channels 0-2, 'conf_pointmaps' from channel 3
        any other count       -> 'depths' from channel 0, 'conf_pointmaps' from the last channel
    """
    feats = out.permute(0, 2, 3, 1)  # move dim 1 last -> channels-last view

    if classifier_mode is not None:
        res = dict(classifier=reg_dense_classifier(feats[..., 0], mode=classifier_mode))
        if conf_mode is not None:
            res['conf_classifier'] = reg_dense_conf(feats[..., 1], mode=conf_mode)
        return res

    if feats.shape[-1] == 4:
        # first three channels are xyz, the last one is confidence
        res = dict(pointmaps=reg_dense_pts3d(feats[..., :3], mode=depth_mode))
    else:
        res = dict(depths=reg_dense_depth(feats[..., 0], mode=depth_mode))
    if conf_mode is not None:
        # NOTE: the depth path also stores its confidence under 'conf_pointmaps'
        res['conf_pointmaps'] = reg_dense_conf(feats[..., -1], mode=conf_mode)
    return res

def reg_dense_classifier(x, mode):
    """
    pass classifier logits through unchanged

    mode: (kind, vmin, vmax) spec; it is unpacked only so a malformed spec
    (not exactly three items) raises, none of its values are used.
    the former sigmoid activation is disabled, so raw x is returned.
    """
    _kind, _lo, _hi = mode  # shape check only
    return x

def reg_dense_depth(x, mode):
    """
    map raw per-pixel values to depth

    mode: (kind, vmin, vmax); only the unbounded spec (-inf, +inf) is accepted (asserted)
      'linear' -> x unchanged
      'square' -> x**2, clipped to [vmin, vmax]
      'exp'    -> exp(x), clipped to [vmin, vmax]
    raises ValueError for any other kind
    """
    mode, lo, hi = mode
    assert (lo == -float('inf')) and (hi == float('inf'))

    if mode == 'linear':
        return x
    if mode == 'square':
        mapped = x.square()
    elif mode == 'exp':
        mapped = torch.exp(x)
    else:
        raise ValueError(f'bad {mode=}')
    # with the asserted infinite bounds this clip is a no-op
    return mapped.clip(min=lo, max=hi)

def reg_dense_pts3d(xyz, mode):
    """
    map raw vectors (along the last dim) to 3D points

    mode: (kind, vmin, vmax); only the unbounded spec (-inf, +inf) is accepted (asserted)
      'linear' -> xyz unchanged
      'square' -> same direction, norm replaced by norm**2
      'exp'    -> same direction, norm replaced by expm1(norm)
    raises ValueError for any other kind
    """
    mode, lo, hi = mode
    unbounded = (lo == -float('inf')) and (hi == float('inf'))
    assert unbounded

    if mode == 'linear':
        # the clip is only reachable when asserts are stripped (python -O)
        return xyz if unbounded else xyz.clip(min=lo, max=hi)

    radius = xyz.norm(dim=-1, keepdim=True)
    direction = xyz / radius.clip(min=1e-8)  # guard against division by zero at the origin

    if mode == 'square':
        return direction * radius.square()
    if mode == 'exp':
        return direction * torch.expm1(radius)
    raise ValueError(f'bad {mode=}')

def reg_dense_conf(x, mode):
    """
    map raw per-pixel values to confidence

    mode: (kind, vmin, vmax)
      'sigmoid' -> vmin + (vmax - vmin) * sigmoid(x), i.e. squashed into (vmin, vmax)
      'exp'     -> vmin + exp(x) capped at (vmax - vmin), i.e. in (vmin, vmax]
    raises ValueError for any other kind
    """
    mode, lo, hi = mode
    if mode == 'sigmoid':
        return (hi - lo) * torch.sigmoid(x) + lo
    if mode == 'exp':
        return lo + x.exp().clip(max=hi - lo)
    raise ValueError(f'bad {mode=}')