# rayst3r/utils/geometry.py
import numpy as np
import torch
from utils.utils import invalid_to_nans, invalid_to_zeros
def compute_pointmap(depth, cam2w, intrinsics):
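    """Back-project a depth map into a world-space pointmap.

    Args:
        depth: (H, W) array of depths along the camera z-axis.
        cam2w: (4, 4) camera-to-world pose.
        intrinsics: (3, 3) pinhole intrinsics matrix K.

    Returns:
        (H, W, 3) array of world-space points, one per pixel.
    """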
fx, fy = intrinsics[0, 0], intrinsics[1, 1]
cx, cy = intrinsics[0, 2], intrinsics[1, 2]
h, w = depth.shape
i, j = np.meshgrid(np.arange(w), np.arange(h), indexing='xy')
x_cam = (i - cx) * depth / fx
y_cam = (j - cy) * depth / fy
points_cam = np.stack([x_cam, y_cam, depth], axis=-1)
points_world = np.dot(cam2w[:3, :3], points_cam.reshape(-1, 3).T).T + cam2w[:3, 3]
points_world = points_world.reshape(h, w, 3)
return points_world
def invert_poses(raw_poses):
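    """Invert a batch of rigid 4x4 poses in closed form.

    Uses the SE(3) inverse [R | t]^-1 = [R^T | -R^T t] rather than a general
    matrix inverse. Accepts any leading batch shape; the input tensor is left
    unmodified.
    """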
    poses = raw_poses.clone()
    original_shape = poses.shape
    poses = poses.reshape(-1, 4, 4)
    R = poses[:, :3, :3].clone()
    t = poses[:, :3, 3].clone()
    poses[:, :3, :3] = R.transpose(1, 2)  # rotation inverse is the transpose
    poses[:, :3, 3] = torch.bmm(-R.transpose(1, 2), t.unsqueeze(-1)).squeeze(-1)  # translation becomes -R^T t
poses = poses.reshape(*original_shape)
return poses
def center_pointmaps_set(cams, w2cs):
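    """Apply the (B, 4, 4) transforms w2cs to a camera dict.

    Transforms cams["pointmaps"] (channel-first or channel-last) in
    homogeneous coordinates, and cams["c2ws"] if present. The dict is
    modified in place and returned.
    """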
    # detect channel-first pointmaps (B, 3, ...) and move channels last
    swap_dim = False
    if cams["pointmaps"].shape[1] == 3:
        swap_dim = True
        cams["pointmaps"] = cams["pointmaps"].transpose(1, -1)
    original_shape = cams["pointmaps"].shape
    B = original_shape[0]
    # recompute pointmaps in camera frame
    pointmaps = cams["pointmaps"]
    pointmaps_h = torch.cat([pointmaps, torch.ones_like(pointmaps[..., :1])], dim=-1)  # homogeneous coordinates
    pointmaps_h = pointmaps_h.reshape(B, -1, 4)
    pointmaps_recentered_h = torch.bmm(w2cs, pointmaps_h.transpose(1, 2)).transpose(1, 2)
    pointmaps_recentered = pointmaps_recentered_h[..., :3] / pointmaps_recentered_h[..., 3:4]
    pointmaps_recentered = pointmaps_recentered.reshape(*original_shape)
    # recompute c2ws
    if "c2ws" in cams:
        c2ws_recentered = torch.bmm(w2cs, cams["c2ws"].reshape(-1, 4, 4))
        c2ws_recentered = c2ws_recentered.reshape(cams["c2ws"].shape)
        cams["c2ws"] = c2ws_recentered
    # assign back and restore the original layout
    cams["pointmaps"] = pointmaps_recentered
    if swap_dim:
        cams["pointmaps"] = cams["pointmaps"].transpose(1, -1)
    return cams
def center_pointmaps(batch):
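    """Express each sample's pointmaps and poses in its new camera's frame.

    The original new-camera poses are stashed in batch["original_poses"] so
    that uncenter_pointmaps can undo the re-centering.
    """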
original_poses = batch["new_cams"]["c2ws"] # assuming first camera is the one we want to predict
w2cs = invert_poses(batch["new_cams"]["c2ws"])
batch["new_cams"] = center_pointmaps_set(batch["new_cams"],w2cs)
batch["input_cams"] = center_pointmaps_set(batch["input_cams"],w2cs)
batch["original_poses"] = original_poses
return batch
def uncenter_pointmaps(pred,gt,batch):
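    """Undo center_pointmaps by re-applying the stored original poses.

    Only the batch dict is transformed back; pred and gt are currently
    returned unchanged (see the commented-out lines below).
    """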
original_poses = batch["original_poses"]
batch["new_cams"] = center_pointmaps_set(batch["new_cams"],original_poses)
batch["input_cams"] = center_pointmaps_set(batch["input_cams"],original_poses)
#gt = center_pointmaps_set(gt,original_poses)
#pred = center_pointmaps_set(pred,original_poses)
return pred, gt, batch
def compute_rays(batch):
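    """Compute per-pixel ray directions for each new camera in the batch.

    Returns a (B, H, W, 2) tensor with the x and y components of the ray
    directions on the z=1 plane; the z component is implicitly 1.
    """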
h, w = batch["new_cams"]["pointmaps"].shape[-3:-1]
B = batch["new_cams"]["pointmaps"].shape[0]
device = batch["new_cams"]["pointmaps"].device
Ks = batch["new_cams"]["Ks"]
    i_s, j_s = torch.meshgrid(torch.arange(w, device=device), torch.arange(h, device=device), indexing='xy')
    i_s, j_s = i_s.repeat(B, 1, 1), j_s.repeat(B, 1, 1)
    f_x = Ks[:, 0, 0].reshape(-1, 1, 1)
    f_y = Ks[:, 1, 1].reshape(-1, 1, 1)
    c_x = Ks[:, 0, 2].reshape(-1, 1, 1)
    c_y = Ks[:, 1, 2].reshape(-1, 1, 1)
    # ray directions on the z=1 plane; the z component is left implicit
    x_cam = (i_s - c_x) / f_x
    y_cam = (j_s - c_y) / f_y
    rays = torch.cat([x_cam.unsqueeze(-1), y_cam.unsqueeze(-1)], dim=-1)
return rays
def normalize_pointcloud(pts1, pts2=None, norm_mode='avg_dis', valid1=None, valid2=None, valid3=None, ret_factor=False, pts3=None):
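    """Normalize one to three pointmaps by a shared per-batch scale factor.

    norm_mode is '<mode>_<dis>' (e.g. 'avg_dis'). With mode 'avg' the factor
    is the mean distance of all valid points to the origin; the distances can
    be raw ('dis'), log1p-compressed ('log1p'), or warped in place
    ('warp-log1p'). Modes 'median' and 'sqrt' use the nan-aware median and
    the squared mean of sqrt-distances, respectively. Returns the scaled
    pointmaps, plus the factor itself when ret_factor is True.
    """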
assert pts1.ndim >= 3 and pts1.shape[-1] == 3
assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3)
norm_mode, dis_mode = norm_mode.split('_')
if norm_mode == 'avg':
# gather all points together (joint normalization)
nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3)
nan_pts2, nnz2 = invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0)
all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
if pts3 is not None:
nan_pts3, nnz3 = invalid_to_zeros(pts3, valid3, ndim=3)
all_pts = torch.cat((all_pts, nan_pts3), dim=1)
nnz1 += nnz3
# compute distance to origin
all_dis = all_pts.norm(dim=-1)
if dis_mode == 'dis':
pass # do nothing
elif dis_mode == 'log1p':
all_dis = torch.log1p(all_dis)
elif dis_mode == 'warp-log1p':
# actually warp input points before normalizing them
log_dis = torch.log1p(all_dis)
warp_factor = log_dis / all_dis.clip(min=1e-8)
H1, W1 = pts1.shape[1:-1]
pts1 = pts1 * warp_factor[:,:W1*H1].view(-1,H1,W1,1)
if pts2 is not None:
H2, W2 = pts2.shape[1:-1]
pts2 = pts2 * warp_factor[:,W1*H1:].view(-1,H2,W2,1)
all_dis = log_dis # this is their true distance afterwards
else:
raise ValueError(f'bad {dis_mode=}')
norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8)
else:
# gather all points together (joint normalization)
nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3)
nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None
all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
# compute distance to origin
all_dis = all_pts.norm(dim=-1)
        if norm_mode == 'avg':
            # note: unreachable here, since 'avg' is handled by the branch above
            norm_factor = all_dis.nanmean(dim=1)
elif norm_mode == 'median':
norm_factor = all_dis.nanmedian(dim=1).values.detach()
elif norm_mode == 'sqrt':
norm_factor = all_dis.sqrt().nanmean(dim=1)**2
else:
raise ValueError(f'bad {norm_mode=}')
norm_factor = norm_factor.clip(min=1e-8)
while norm_factor.ndim < pts1.ndim:
norm_factor.unsqueeze_(-1)
res = (pts1 / norm_factor,)
if pts2 is not None:
res = res + (pts2 / norm_factor,)
if pts3 is not None:
res = res + (pts3 / norm_factor,)
if ret_factor:
res = res + (norm_factor,)
return res
def compute_pointmap_torch(depth, cam2w, intrinsics, device='cuda'):
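    """Torch equivalent of compute_pointmap, evaluated on `device`.

    Back-projects an (H, W) depth map through the intrinsics and maps the
    result into world space with cam2w, returning an (H, W, 3) pointmap.
    """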
fx, fy = intrinsics[0, 0], intrinsics[1, 1]
cx, cy = intrinsics[0, 2], intrinsics[1, 2]
h, w = depth.shape
    i, j = torch.meshgrid(torch.arange(w, device=device), torch.arange(h, device=device), indexing='xy')
x_cam = (i - cx) * depth / fx
y_cam = (j - cy) * depth / fy
points_cam = torch.stack([x_cam, y_cam, depth], dim=-1)
points_world = (cam2w[:3, :3] @ points_cam.reshape(-1, 3).T).T + cam2w[:3, 3]
points_world = points_world.reshape(h, w, 3)
return points_world
def depth2pts(depths, Ks):
"""
Convert depth map to 3D points
"""
device = depths.device
B = depths.shape[0]
    pts = []
    # compute_pointmap_torch is unbatched, so back-project one sample at a time
    for b in range(B):
        depth_b = depths[b]
        K = Ks[b]
        pts.append(compute_pointmap_torch(depth_b, torch.eye(4).to(device), K, device))
pts = torch.stack(pts, dim=0)
return pts
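
if __name__ == "__main__":
    # Minimal usage sketch with synthetic, illustrative values (the shapes and
    # numbers below are assumptions for demonstration, not project defaults).
    h, w = 4, 6
    K_np = np.array([[100.0, 0.0, w / 2],
                     [0.0, 100.0, h / 2],
                     [0.0, 0.0, 1.0]])
    depth = np.full((h, w), 2.0)
    # back-project a constant-depth image with an identity camera pose
    pts = compute_pointmap(depth, np.eye(4), K_np)
    assert pts.shape == (h, w, 3)

    # invert_poses implements the closed-form SE(3) inverse: round-tripping
    # a batch of poses should recover the identity
    c2ws = torch.eye(4).repeat(2, 1, 1)
    w2cs = invert_poses(c2ws)
    assert torch.allclose(torch.bmm(w2cs, c2ws), torch.eye(4).repeat(2, 1, 1))

    # depth2pts is the batched torch variant (camera-frame points)
    depths = torch.full((2, h, w), 2.0)
    Ks = torch.tensor(K_np, dtype=torch.float32).repeat(2, 1, 1)
    pts_t = depth2pts(depths, Ks)
    assert pts_t.shape == (2, h, w, 3)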