import numpy as np
import matplotlib.pyplot as plt
import torch

def colorize_points_with_turbo_all_dims(points, method='norm', cmap='turbo'):
    """
    Assigns colors to 3D points using a matplotlib colormap, based on a scalar
    computed from all 3 dimensions.

    Args:
        points (np.ndarray): (N, 3) array of 3D points.
        method (str): Method for reducing each 3D point to a scalar.
            Options: 'norm' (Euclidean norm), 'pca' (projection onto the
            first principal axis).
        cmap (str): Name of the matplotlib colormap. Default: 'turbo'.

    Returns:
        np.ndarray: (N, 3) RGB colors in [0, 1].
    """
    assert points.shape[1] == 3, "Input must be of shape (N, 3)"

    if method == 'norm':
        scalar = np.linalg.norm(points, axis=1)
    elif method == 'pca':
        # Project onto first principal component
        mean = points.mean(axis=0)
        centered = points - mean
        _, _, vh = np.linalg.svd(centered, full_matrices=False)  # only the principal axes are needed
        scalar = centered @ vh[0]  # Project onto first principal axis
    else:
        raise ValueError(f"Unknown method '{method}'")

    # Normalize scalar to [0, 1]
    scalar_min, scalar_max = scalar.min(), scalar.max()
    normalized = (scalar - scalar_min) / (scalar_max - scalar_min + 1e-8)

    # Map the normalized scalars through the requested colormap
    colormap = plt.colormaps.get_cmap(cmap)
    colors = colormap(normalized)[:, :3]  # Drop the alpha channel

    return colors
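

# A minimal usage sketch (illustrative, not part of the original pipeline):
# the point cloud below is synthetic and `_demo_colorize` is a hypothetical
# helper name.
def _demo_colorize():
    points = np.random.randn(1000, 3)
    for method in ('norm', 'pca'):
        colors = colorize_points_with_turbo_all_dims(points, method=method)
        assert colors.shape == (1000, 3)
        assert colors.min() >= 0.0 and colors.max() <= 1.0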


def transform_pointmap(pointmap_cam, c2w):
    # pointmap_cam: H x W x 3 points in the camera frame
    # c2w: 4 x 4 camera-to-world transform
    # Transform the pointmap to the world frame in homogeneous coordinates.
    # ones_like keeps the homogeneous column on the same device/dtype as the
    # input (torch.cat requires matching dtypes).
    pointmap_cam_h = torch.cat([pointmap_cam, torch.ones_like(pointmap_cam[..., :1])], dim=-1)
    pointmap_world_h = pointmap_cam_h @ c2w.T
    pointmap_world = pointmap_world_h[..., :3] / pointmap_world_h[..., 3:4]
    return pointmap_world
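

# A round-trip sketch (illustrative only): applying a camera-to-world
# transform and then its inverse should recover the input pointmap.
# `_demo_transform_pointmap` is a hypothetical helper name.
def _demo_transform_pointmap():
    pointmap_cam = torch.randn(4, 5, 3)
    c2w = torch.eye(4)
    c2w[:3, 3] = torch.tensor([1.0, 2.0, 3.0])  # pure translation
    pointmap_world = transform_pointmap(pointmap_cam, c2w)
    pointmap_back = transform_pointmap(pointmap_world, torch.linalg.inv(c2w))
    assert torch.allclose(pointmap_back, pointmap_cam, atol=1e-5)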

def filter_all_masks(pred_dict, batch, max_outlier_views=1):
    """
    Filters the per-view validity masks by cross-view consistency: every
    pixel's world-space point from view j is reprojected into all views i,
    and a pixel stays valid only if its reprojection lands inside the
    thresholded mask (while lying in front of view i's surface) in at most
    `max_outlier_views` views.
    """
    # Pixels whose predicted probability falls below 0.5
    pred_masks = torch.sigmoid(pred_dict['classifier'][0]) < 0.5  # [V, H, W] bool
    n_views, H, W = pred_masks.shape
    device = pred_masks.device

    K = batch['input_cams']['Ks'][0][0]  # [3, 3]
    c2ws = batch['new_cams']['c2ws'][0]  # [V, 4, 4]
    w2cs = torch.linalg.inv(c2ws)        # [V, 4, 4]

    pointmaps = pred_dict['pointmaps'][0]  # [V, H, W, 3]
    pointmaps_h = torch.cat([pointmaps, torch.ones_like(pointmaps[..., :1])], dim=-1)  # [V, H, W, 4]

    visibility_count = torch.zeros((n_views, H, W), dtype=torch.int32, device=device)

    for j in range(n_views):
        # Project pointmap j into every view i
        pmap_h = pointmaps_h[j]  # [H, W, 4], world-space points from view j
        pmap_h = pmap_h.view(1, H, W, 4).expand(n_views, -1, -1, -1)  # [V, H, W, 4]

        # Compute T_{i←j} = w2cs[i] @ c2ws[j]
        T = w2cs @ c2ws[j]  # [V, 4, 4]
        T = T.view(n_views, 1, 1, 4, 4)  # [V, 1, 1, 4, 4]

        # Transform to i-th camera frame
        pts_cam = torch.matmul(T, pmap_h.unsqueeze(-1)).squeeze(-1)[..., :3]  # [V, H, W, 3]

        # Project to image
        img_coords = torch.matmul(pts_cam, K.T)  # [V, H, W, 3]
        img_coords = img_coords[..., :2] / img_coords[..., 2:3].clamp(min=1e-6)
        img_coords = img_coords.round().long()  # [V, H, W, 2]

        x = img_coords[..., 0].clamp(0, W - 1)
        y = img_coords[..., 1].clamp(0, H - 1)
        valid = (img_coords[..., 0] >= 0) & (img_coords[..., 0] < W) & \
                (img_coords[..., 1] >= 0) & (img_coords[..., 1] < H) & \
                (pts_cam[..., 2] > 0)  # reject points behind camera i

        # Get depth of the reprojected point from j into i
        reprojected_depth = pts_cam[..., 2]  # [V, H, W]

        # Get depth of each view's original pointmap
        target_depth = pointmaps[:, :, :, 2]  # [V, H, W]

        # Lookup the depth value in view i at the projected location (x, y)
        depth_at_pixel = target_depth[torch.arange(n_views).view(-1, 1, 1), y, x]  # [V, H, W]

        # The reprojected point must lie strictly in front of the surface seen
        # by view i. For i == j the depths coincide, so the strict '<' generally
        # keeps the self-view from counting itself.
        is_closest = reprojected_depth < depth_at_pixel  # [V, H, W]

        # Lookup mask values at projected location
        projected_mask = pred_masks[torch.arange(n_views).view(-1, 1, 1), y, x] & valid  # [V, H, W]

        # Only consider as visible if it’s within mask and closest point
        visible = projected_mask & is_closest  # [V, H, W]

        # Count how many views see each pixel from j
        visibility_count[j] = visible.sum(dim=0)

    visibility_mask = visibility_count <= max_outlier_views
    batch['new_cams']['valid_masks'] = visibility_mask & batch['new_cams']['valid_masks']
    return batch
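

# A smoke-test sketch with synthetic tensors. The shapes and dict keys below
# are assumptions inferred from how filter_all_masks indexes `pred_dict` and
# `batch`; they are not documented by the original code, and
# `_demo_filter_all_masks` is a hypothetical helper name.
def _demo_filter_all_masks():
    V, H, W = 4, 8, 8
    pred_dict = {
        'classifier': torch.randn(1, V, H, W),
        'pointmaps': torch.randn(1, V, H, W, 3),
    }
    batch = {
        'input_cams': {'Ks': torch.eye(3).view(1, 1, 3, 3)},
        'new_cams': {
            'c2ws': torch.eye(4).repeat(1, V, 1, 1),
            'valid_masks': torch.ones(V, H, W, dtype=torch.bool),
        },
    }
    batch = filter_all_masks(pred_dict, batch)
    assert batch['new_cams']['valid_masks'].shape == (V, H, W)


if __name__ == "__main__":
    _demo_colorize()
    _demo_transform_pointmap()
    _demo_filter_all_masks()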