import copy

import numpy as np
import torch

from utils.utils import invalid_to_nans, invalid_to_zeros
def compute_pointmap(depth, cam2w, intrinsics):
    """Back-project an (h, w) depth map into an (h, w, 3) pointmap in world coordinates."""
    fx, fy = intrinsics[0, 0], intrinsics[1, 1]
    cx, cy = intrinsics[0, 2], intrinsics[1, 2]
    h, w = depth.shape

    # Pixel grid: i indexes columns (x), j indexes rows (y).
    i, j = np.meshgrid(np.arange(w), np.arange(h), indexing='xy')

    # Unproject pixels to camera coordinates with the pinhole model.
    x_cam = (i - cx) * depth / fx
    y_cam = (j - cy) * depth / fy

    # Move the points from camera to world coordinates with the cam2world pose.
    points_cam = np.stack([x_cam, y_cam, depth], axis=-1)
    points_world = np.dot(cam2w[:3, :3], points_cam.reshape(-1, 3).T).T + cam2w[:3, 3]
    points_world = points_world.reshape(h, w, 3)

    return points_world
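
# Hypothetical usage sketch (the shapes and intrinsics below are assumed, not part
# of the original code):
#   depth = np.full((480, 640), 2.0, dtype=np.float32)   # constant 2 m depth
#   K = np.array([[500.0, 0.0, 320.0],
#                 [0.0, 500.0, 240.0],
#                 [0.0, 0.0, 1.0]])
#   pts = compute_pointmap(depth, np.eye(4), K)           # -> (480, 640, 3) world points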


def invert_poses(raw_poses):
    """Invert a (..., 4, 4) batch of rigid cam2world poses into world2cam poses."""
    poses = copy.deepcopy(raw_poses)
    original_shape = poses.shape
    poses = poses.reshape(-1, 4, 4)
    R = copy.deepcopy(poses[:, :3, :3])
    t = copy.deepcopy(poses[:, :3, 3])
    # For a rigid transform [R | t], the inverse is [R^T | -R^T t].
    poses[:, :3, :3] = R.transpose(1, 2)
    poses[:, :3, 3] = torch.bmm(-R.transpose(1, 2), t.unsqueeze(-1)).squeeze(-1)
    poses = poses.reshape(*original_shape)
    return poses


def center_pointmaps_set(cams, w2cs):
    """Express the pointmaps (and optional c2w poses) of one camera set in the frame defined by w2cs."""
    # Accept channels-first pointmaps (B, 3, ...) and temporarily move channels last.
    swap_dim = False
    if cams["pointmaps"].shape[1] == 3:
        swap_dim = True
        cams["pointmaps"] = cams["pointmaps"].transpose(1, -1)

    original_shape = cams["pointmaps"].shape
    device = cams["pointmaps"].device
    B = original_shape[0]

    # Transform the points in homogeneous coordinates: x' = w2c @ x.
    pointmaps = cams["pointmaps"]
    pointmaps_h = torch.cat([pointmaps, torch.ones(pointmaps.shape[:-1] + (1,)).to(device)], dim=-1)
    pointmaps_h = pointmaps_h.reshape(B, -1, 4)
    pointmaps_recentered_h = torch.bmm(w2cs, pointmaps_h.transpose(1, 2)).transpose(1, 2)
    pointmaps_recentered = pointmaps_recentered_h[..., :3] / pointmaps_recentered_h[..., 3:4]
    pointmaps_recentered = pointmaps_recentered.reshape(*original_shape)

    # Re-express the camera poses in the same frame.
    if "c2ws" in cams:
        c2ws_recentered = torch.bmm(w2cs, cams["c2ws"].reshape(-1, 4, 4))
        c2ws_recentered = c2ws_recentered.reshape(cams["c2ws"].shape)
        cams["c2ws"] = c2ws_recentered

    cams["pointmaps"] = pointmaps_recentered
    if swap_dim:
        cams["pointmaps"] = cams["pointmaps"].transpose(1, -1)
    return cams


def center_pointmaps(batch):
    """Re-center all pointmaps and poses in the batch on the new_cams camera frame."""
    original_poses = batch["new_cams"]["c2ws"]
    w2cs = invert_poses(batch["new_cams"]["c2ws"])

    batch["new_cams"] = center_pointmaps_set(batch["new_cams"], w2cs)
    batch["input_cams"] = center_pointmaps_set(batch["input_cams"], w2cs)
    # Keep the original poses so the centering can be undone later.
    batch["original_poses"] = original_poses
    return batch
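
# Hypothetical batch layout assumed by center_pointmaps / uncenter_pointmaps
# (inferred from how the fields are used above, not documented in the original):
#   batch = {
#       "new_cams":   {"pointmaps": (B, H, W, 3) or (B, 3, H, W), "c2ws": (B, 4, 4), "Ks": (B, 3, 3)},
#       "input_cams": {"pointmaps": (B, H, W, 3) or (B, 3, H, W), "c2ws": (B, 4, 4)},
#   }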


def uncenter_pointmaps(pred, gt, batch):
    """Undo center_pointmaps by mapping the batch back to the original world frame."""
    original_poses = batch["original_poses"]

    batch["new_cams"] = center_pointmaps_set(batch["new_cams"], original_poses)
    batch["input_cams"] = center_pointmaps_set(batch["input_cams"], original_poses)

    # pred and gt are passed through unchanged.
    return pred, gt, batch


def compute_rays(batch):
    """Compute per-pixel ray directions on the z = 1 plane for the new_cams intrinsics, shape (B, h, w, 2)."""
    h, w = batch["new_cams"]["pointmaps"].shape[-3:-1]
    B = batch["new_cams"]["pointmaps"].shape[0]
    device = batch["new_cams"]["pointmaps"].device
    Ks = batch["new_cams"]["Ks"]

    # Pixel grid, replicated across the batch.
    i_s, j_s = np.meshgrid(np.arange(w), np.arange(h), indexing='xy')
    i_s, j_s = torch.tensor(i_s).repeat(B, 1, 1).to(device), torch.tensor(j_s).repeat(B, 1, 1).to(device)

    f_x = Ks[:, 0, 0].reshape(-1, 1, 1)
    f_y = Ks[:, 1, 1].reshape(-1, 1, 1)
    c_x = Ks[:, 0, 2].reshape(-1, 1, 1)
    c_y = Ks[:, 1, 2].reshape(-1, 1, 1)

    # Inverse pinhole projection at unit depth.
    x_cam = (i_s - c_x) / f_x
    y_cam = (j_s - c_y) / f_y
    rays = torch.cat([x_cam.unsqueeze(-1), y_cam.unsqueeze(-1)], dim=-1)
    return rays


def normalize_pointcloud(pts1, pts2=None, norm_mode='avg_dis', valid1=None, valid2=None, valid3=None, ret_factor=False, pts3=None):
    """Jointly normalize up to three pointmaps by a robust scale factor ('avg', 'median' or 'sqrt' of point distances)."""
    assert pts1.ndim >= 3 and pts1.shape[-1] == 3
    assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3)
    norm_mode, dis_mode = norm_mode.split('_')

    if norm_mode == 'avg':
        # Gather all valid points together (joint normalization), zeroing out invalid ones.
        nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3)
        nan_pts2, nnz2 = invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0)
        all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
        if pts3 is not None:
            nan_pts3, nnz3 = invalid_to_zeros(pts3, valid3, ndim=3)
            all_pts = torch.cat((all_pts, nan_pts3), dim=1)
            nnz1 += nnz3

        all_dis = all_pts.norm(dim=-1)
        if dis_mode == 'dis':
            pass  # plain Euclidean distance
        elif dis_mode == 'log1p':
            all_dis = torch.log1p(all_dis)
        elif dis_mode == 'warp-log1p':
            # Rescale the points so that their distances become log1p(distance)
            # (pts3 is intentionally left unwarped here, matching the original behaviour).
            log_dis = torch.log1p(all_dis)
            warp_factor = log_dis / all_dis.clip(min=1e-8)
            H1, W1 = pts1.shape[1:-1]
            pts1 = pts1 * warp_factor[:, :W1 * H1].view(-1, H1, W1, 1)
            if pts2 is not None:
                H2, W2 = pts2.shape[1:-1]
                pts2 = pts2 * warp_factor[:, W1 * H1:].view(-1, H2, W2, 1)
            all_dis = log_dis
        else:
            raise ValueError(f'bad {dis_mode=}')

        norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8)
    else:
        # 'median' / 'sqrt' modes: mask invalid points with NaNs and use nan-aware statistics
        # (pts3 does not contribute to the statistic here, but is still rescaled below).
        nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3)
        nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None
        all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1

        all_dis = all_pts.norm(dim=-1)

        if norm_mode == 'avg':
            norm_factor = all_dis.nanmean(dim=1)
        elif norm_mode == 'median':
            norm_factor = all_dis.nanmedian(dim=1).values.detach()
        elif norm_mode == 'sqrt':
            norm_factor = all_dis.sqrt().nanmean(dim=1)**2
        else:
            raise ValueError(f'bad {norm_mode=}')

    norm_factor = norm_factor.clip(min=1e-8)
    while norm_factor.ndim < pts1.ndim:
        norm_factor.unsqueeze_(-1)

    res = (pts1 / norm_factor,)
    if pts2 is not None:
        res = res + (pts2 / norm_factor,)
    if pts3 is not None:
        res = res + (pts3 / norm_factor,)
    if ret_factor:
        res = res + (norm_factor,)
    return res
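
# Hypothetical usage sketch (shapes and variable names assumed): jointly rescale two
# pointmaps by their average point distance and also return the scale factor.
#   pts1_n, pts2_n, factor = normalize_pointcloud(pts1, pts2, norm_mode='avg_dis',
#                                                 valid1=mask1, valid2=mask2,
#                                                 ret_factor=True)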


def compute_pointmap_torch(depth, cam2w, intrinsics, device='cuda'):
    """Torch version of compute_pointmap: back-project an (h, w) depth map to (h, w, 3) world points."""
    fx, fy = intrinsics[0, 0], intrinsics[1, 1]
    cx, cy = intrinsics[0, 2], intrinsics[1, 2]
    h, w = depth.shape

    # Pixel grid on the requested device (i indexes columns, j indexes rows).
    i, j = torch.meshgrid(torch.arange(w).to(device), torch.arange(h).to(device), indexing='xy')
    x_cam = (i - cx) * depth / fx
    y_cam = (j - cy) * depth / fy

    # Camera-space points, then apply the cam2world rigid transform.
    points_cam = torch.stack([x_cam, y_cam, depth], dim=-1)
    points_world = (cam2w[:3, :3] @ points_cam.reshape(-1, 3).T).T + cam2w[:3, 3]
    points_world = points_world.reshape(h, w, 3)

    return points_world


def depth2pts(depths, Ks):
    """Convert a batch of depth maps (B, h, w) into camera-space pointmaps (B, h, w, 3)."""
    device = depths.device
    B = depths.shape[0]
    pts = []
    for b in range(B):
        depth_b = depths[b]
        K = Ks[b]
        # Identity pose: the points stay in each view's camera frame.
        pts.append(compute_pointmap_torch(depth_b, torch.eye(4).to(device), K, device))
    pts = torch.stack(pts, dim=0)
    return pts
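

if __name__ == "__main__":
    # Minimal smoke test, not part of the original module: exercises the geometry
    # helpers on synthetic CPU data with assumed shapes. Running this file directly
    # still requires the repo's utils package to be importable.
    B, h, w = 2, 4, 6
    depths = torch.ones(B, h, w)
    Ks = torch.tensor([[100.0, 0.0, w / 2],
                       [0.0, 100.0, h / 2],
                       [0.0, 0.0, 1.0]]).repeat(B, 1, 1)

    pts = depth2pts(depths, Ks)
    print("depth2pts:", tuple(pts.shape))  # expected (B, h, w, 3)

    c2ws = torch.eye(4).repeat(B, 1, 1)
    w2cs = invert_poses(c2ws)
    print("invert_poses round trip:",
          torch.allclose(torch.bmm(w2cs, c2ws), torch.eye(4).expand(B, 4, 4)))

    batch = {"new_cams": {"pointmaps": pts, "Ks": Ks}}
    rays = compute_rays(batch)
    print("compute_rays:", tuple(rays.shape))  # expected (B, h, w, 2)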