|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from scipy.stats import norm, lognorm |
|
import torch |
|
|
|
def colorize_points_with_turbo_all_dims(points, method='norm', cmap='turbo'):
    """
    Assign an RGB color to each 3D point from a scalar derived from all 3 coordinates.

    Each point is reduced to one scalar, the scalars are min-max normalized to
    [0, 1] over the whole input, and the normalized values are looked up in a
    Matplotlib colormap.

    Args:
        points (np.ndarray): (N, 3) array of 3D points.
        method (str): How each point is reduced to a scalar.
            'norm' uses the Euclidean norm of the point.
            'pca' uses the projection of the mean-centered point onto the
            first right-singular vector of the centered point set.
        cmap (str): Colormap passed to ``plt.colormaps.get_cmap``.
            Defaults to 'turbo'.

    Returns:
        np.ndarray: (N, 3) RGB colors in [0, 1]. An empty (0, 3) array is
        returned when ``points`` has no rows.

    Raises:
        AssertionError: If ``points`` is not of shape (N, 3).
        ValueError: If ``method`` is not 'norm' or 'pca'.
    """
    # Check ndim first so a 1-D input reports the shape message instead of
    # failing with an IndexError on shape[1].
    assert points.ndim == 2 and points.shape[1] == 3, "Input must be of shape (N, 3)"

    if method not in ('norm', 'pca'):
        raise ValueError(f"Unknown method '{method}'")

    # min()/max() (and the SVD) are undefined for zero points; nothing to color.
    if points.shape[0] == 0:
        return np.empty((0, 3), dtype=float)

    if method == 'norm':
        scalar = np.linalg.norm(points, axis=1)
    else:
        centered = points - points.mean(axis=0)
        # Rows of vh are the principal directions; vh[0] is the dominant one.
        _, _, vh = np.linalg.svd(centered, full_matrices=False)
        scalar = centered @ vh[0]

    # The epsilon keeps the division finite when all scalars are equal; in that
    # case every point maps to 0, i.e. the low end of the colormap.
    scalar_min, scalar_max = scalar.min(), scalar.max()
    normalized = (scalar - scalar_min) / (scalar_max - scalar_min + 1e-8)

    colormap = plt.colormaps.get_cmap(cmap)
    # The colormap returns RGBA; drop the alpha channel.
    return colormap(normalized)[:, :3]
|
|
|
|
|
def transform_pointmap(pointmap_cam, c2w):
    """
    Apply a 4x4 homogeneous transform to a map of 3D points.

    Args:
        pointmap_cam (torch.Tensor): (..., 3) tensor of 3D points.
        c2w (torch.Tensor): (4, 4) homogeneous transform applied to every
            point (by its name, camera-to-world).

    Returns:
        torch.Tensor: (..., 3) transformed points, same leading shape as
        ``pointmap_cam``.
    """
    # ones_like keeps the dtype and device of the input, so the homogeneous
    # coordinate never needs a dtype promotion or a host-to-device copy.
    ones = torch.ones_like(pointmap_cam[..., :1])
    pointmap_cam_h = torch.cat([pointmap_cam, ones], dim=-1)

    # Points are stored as row vectors, so right-multiply by the transpose:
    # (c2w @ p)^T == p^T @ c2w^T.
    pointmap_world_h = pointmap_cam_h @ c2w.transpose(-1, -2)

    # Divide by the homogeneous coordinate (equal to 1 for a rigid transform).
    return pointmap_world_h[..., :3] / pointmap_world_h[..., 3:4]
|
|
|
@torch.no_grad()
def filter_all_masks(pred_dict, batch, max_outlier_views=1):
    """
    Update ``batch['new_cams']['valid_masks']`` with a cross-view reprojection test.

    For every pixel of every view j, its predicted 3D point is reprojected into
    each view i. The pixel's counter is incremented for each view i in which
    the reprojected point:

      * has positive depth and lands inside the image bounds,
      * lands on a pixel whose classifier probability is below 0.5, and
      * is strictly closer to camera i than view i's own predicted depth at
        that pixel.

    Pixels whose counter exceeds ``max_outlier_views`` are cleared from
    ``batch['new_cams']['valid_masks']``.

    Only index 0 of the leading (presumably batch) dimension of each input is
    read. NOTE(review): confirm callers always use batch size 1.

    Args:
        pred_dict (dict): Reads ``pred_dict['classifier'][0]`` as per-pixel
            logits of shape (n_views, H, W) and ``pred_dict['pointmaps'][0]``
            of shape (n_views, H, W, 3). The pointmaps are treated as points in
            each view's own camera frame, because they are multiplied by that
            view's ``c2ws`` entry before reprojection.
        batch (dict): Reads ``batch['input_cams']['Ks'][0][0]`` as the 3x3
            intrinsics shared by all views, ``batch['new_cams']['c2ws'][0]``
            as (n_views, 4, 4) camera-to-world matrices, and
            ``batch['new_cams']['valid_masks']`` as a boolean mask that
            broadcasts against (n_views, H, W).
        max_outlier_views (int): Largest counter value a pixel may have and
            still be kept.

    Returns:
        dict: The same ``batch`` object, with
        ``batch['new_cams']['valid_masks']`` replaced in place.
    """
    # The result is a boolean mask, so no autograd graph is needed (hence the
    # no_grad decorator).
    pred_masks = torch.sigmoid(pred_dict['classifier'][0]).float() < 0.5
    n_views, H, W = pred_masks.shape
    device = pred_masks.device

    K = batch['input_cams']['Ks'][0][0]
    c2ws = batch['new_cams']['c2ws'][0]
    w2cs = torch.linalg.inv(c2ws)

    pointmaps = pred_dict['pointmaps'][0]
    pointmaps_h = torch.cat([pointmaps, torch.ones_like(pointmaps[..., :1])], dim=-1)

    # Loop-invariant: each view's own depth (z) map, and a view index that
    # broadcasts against the (n_views, H, W) pixel indices.
    target_depth = pointmaps[..., 2]
    view_idx = torch.arange(n_views, device=device).view(-1, 1, 1)

    visibility_count = torch.zeros((n_views, H, W), dtype=torch.int32, device=device)

    for j in range(n_views):
        # Relative transforms from camera j to every camera i: (n_views, 4, 4).
        cam_j_to_cams = w2cs @ c2ws[j]

        # Row-vector form of T @ p, keeping only the xyz rows of T:
        # (H, W, 4) @ (n_views, 1, 4, 3) -> (n_views, H, W, 3).
        pts_cam = pointmaps_h[j] @ cam_j_to_cams[:, None, :3, :].transpose(-1, -2)
        reprojected_depth = pts_cam[..., 2]

        # Pinhole projection to integer pixel coordinates.
        img_coords = pts_cam @ K.T
        img_coords = img_coords[..., :2] / img_coords[..., 2:3].clamp(min=1e-6)
        img_coords = img_coords.round().long()
        u, v = img_coords[..., 0], img_coords[..., 1]

        # A point at or behind the target camera has no valid projection. The
        # clamped divide above can still place it in bounds, and its negative
        # depth would always pass the "closer" test, so reject it explicitly.
        valid = (
            (u >= 0) & (u < W)
            & (v >= 0) & (v < H)
            & (reprojected_depth > 0)
        )

        # Clamp only to keep the gather in range; `valid` masks those lookups.
        x = u.clamp(0, W - 1)
        y = v.clamp(0, H - 1)

        depth_at_pixel = target_depth[view_idx, y, x]
        is_closest = reprojected_depth < depth_at_pixel
        projected_mask = pred_masks[view_idx, y, x] & valid

        # Count the target views in which this pixel's point passes all tests.
        visibility_count[j] = (projected_mask & is_closest).sum(dim=0)

    visibility_mask = visibility_count <= max_outlier_views
    batch['new_cams']['valid_masks'] = visibility_mask & batch['new_cams']['valid_masks']
    return batch