liguang0115's picture
Add initial project structure with core files, configurations, and sample images
2df809d
import torch.nn as nn
import torch
import roma
import numpy as np
import cv2
from functools import cache
def todevice(batch, device, callback=None, non_blocking=False):
    """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).

    batch: list, tuple, dict of tensors or other things (containers may be nested)
    device: pytorch device or 'numpy'
    callback: function that is called on every sub-element (containers and
        leaves alike) before it is processed; its return value replaces the element.
    non_blocking: forwarded to ``Tensor.to`` for every tensor that is moved.

    Returns an object with the same container structure as ``batch``.
    Non-tensor, non-array leaves (including ``None``) are returned unchanged.
    """
    if callback:
        batch = callback(batch)

    def recurse(item):
        # Forward callback/non_blocking so nested elements are treated exactly
        # like the top-level one.
        return todevice(item, device, callback=callback, non_blocking=non_blocking)

    if isinstance(batch, dict):
        return {k: recurse(v) for k, v in batch.items()}

    if isinstance(batch, (tuple, list)):
        converted = [recurse(x) for x in batch]
        if hasattr(batch, "_fields"):
            # namedtuple constructors take positional fields, not one iterable
            return type(batch)(*converted)
        return type(batch)(converted)

    x = batch
    if device == "numpy":
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x


to_device = todevice  # alias
def to_numpy(x):
    """Run ``todevice`` on ``x`` with the special ``"numpy"`` target."""
    return todevice(x, device="numpy")
def to_cpu(x):
    """Run ``todevice`` on ``x`` with the ``"cpu"`` device."""
    return todevice(x, device="cpu")
def to_cuda(x):
    """Run ``todevice`` on ``x`` with the ``"cuda"`` device."""
    return todevice(x, device="cuda")
def signed_log1p(x):
    """Return ``sign(x) * log(1 + |x|)`` computed element-wise on a tensor."""
    magnitude = torch.abs(x)
    return torch.sign(x) * torch.log1p(magnitude)
def l2_dist(a, b, weight):
    """Weighted *squared* Euclidean distance between ``a`` and ``b``.

    The squared differences are summed over the last dimension (no square
    root is taken) and the result is multiplied by ``weight``.
    """
    diff = a - b
    squared = diff.square().sum(dim=-1)
    return squared * weight
def l1_dist(a, b, weight):
    """Weighted Euclidean norm of ``a - b`` taken over the last dimension.

    Despite the name, the per-point distance is the default (2-)norm of the
    difference vector; the result is multiplied by ``weight``.
    """
    diff = a - b
    length = diff.norm(dim=-1)
    return length * weight
# Registry mapping a distance name to its implementation.
ALL_DISTS = {"l1": l1_dist, "l2": l2_dist}
def _check_edges(edges):
    """Validate that the indices used by ``edges`` are exactly 0..n-1.

    Returns n, the number of distinct indices. Raises ``AssertionError`` if
    some value in the range is missing.
    """
    used = set()
    for edge in edges:
        used.update(edge)
    indices = sorted(used)
    assert indices == list(range(len(indices))), "bad pair indices: missing values "
    return len(indices)
def NoGradParamDict(x):
    """Wrap dict ``x`` in an ``nn.ParameterDict`` with gradients disabled.

    Raises ``AssertionError`` if ``x`` is not a dict.
    """
    assert isinstance(x, dict)
    params = nn.ParameterDict(x)
    return params.requires_grad_(False)
def edge_str(i, j):
    """Build the string key ``"<i>_<j>"`` identifying the edge (i, j)."""
    return "{}_{}".format(i, j)
def i_j_ij(ij):
    """Pair the string key of edge ``ij`` (an ``(i, j)`` sequence) with ``ij`` itself."""
    key = edge_str(*ij)
    return key, ij
def edge_conf(conf_i, conf_j):
    """Score an edge as the product of the two mean confidences, as a Python float."""
    mean_i = conf_i.mean()
    mean_j = conf_j.mean()
    return float(mean_i * mean_j)
def get_imshapes(edges, pred_i, pred_j):
    """Collect the (first two dims of the) prediction shape of every image.

    edges: sequence of (i, j) image-index pairs.
    pred_i, pred_j: per-edge mappings; the e-th entries are read for edge e.

    Returns a list indexed by image id (length ``max index + 1``). Entries of
    images that appear in no edge stay ``None``. Raises ``AssertionError``
    when an image is seen with a shape different from the one recorded earlier.
    """
    # NOTE(review): "pts3d_is_self_view" looks like a typo of
    # "pts3d_in_self_view" -- confirm against whatever produces pred_i.
    n_imgs = max(map(max, edges)) + 1
    imshapes = [None] * n_imgs
    for e, (i, j) in enumerate(edges):
        entry_i, entry_j = pred_i[e], pred_j[e]
        shape_i = tuple(entry_i["pts3d_is_self_view"].shape[0:2])
        shape_j = tuple(entry_j["pts3d_in_other_view"].shape[0:2])

        # both consistency checks run before anything is recorded
        if imshapes[i]:
            assert imshapes[i] == shape_i, f"incorrect shape for image {i}"
        if imshapes[j]:
            assert imshapes[j] == shape_j, f"incorrect shape for image {j}"

        imshapes[i], imshapes[j] = shape_i, shape_j
    return imshapes
def get_conf_trf(mode):
    """Return the confidence transform selected by ``mode``.

    Supported modes: "log" (``x.log()``), "sqrt" (``x.sqrt()``),
    "m1" (``x - 1``) and "id"/"none" (identity).
    Raises ``ValueError`` for any other mode.
    """
    if mode == "log":

        def conf_trf(x):
            return x.log()

        return conf_trf

    if mode == "sqrt":

        def conf_trf(x):
            return x.sqrt()

        return conf_trf

    if mode == "m1":

        def conf_trf(x):
            return x - 1

        return conf_trf

    if mode in ("id", "none"):

        def conf_trf(x):
            return x

        return conf_trf

    raise ValueError(f"bad mode for {mode=}")
@torch.no_grad()
def _compute_img_conf(imshapes, device, edges, edge2conf_i, edge2conf_j):
    """Build one confidence map per image as the element-wise max over its edges.

    Each map starts as zeros of the shape given in ``imshapes``; for every
    edge (i, j) the maps of i and j are max-combined with the entries of
    ``edge2conf_i`` / ``edge2conf_j`` stored under the edge's string key.
    Returns an ``nn.ParameterList``.
    """
    zero_maps = [torch.zeros(hw, device=device) for hw in imshapes]
    im_conf = nn.ParameterList(zero_maps)
    for i, j in edges:
        key = edge_str(i, j)
        im_conf[i] = torch.maximum(im_conf[i], edge2conf_i[key])
        im_conf[j] = torch.maximum(im_conf[j], edge2conf_j[key])
    return im_conf
def xy_grid(
    W,
    H,
    device=None,
    origin=(0, 0),
    unsqueeze=None,
    cat_dim=-1,
    homogeneous=False,
    **arange_kw,
):
    """Output a (H,W,2) grid of pixel coordinates
        with output[j,i,0] = i + origin[0]
             output[j,i,1] = j + origin[1]

    device: None builds numpy arrays, anything else builds torch tensors on
        that device.
    origin: (x0, y0) offset added to the coordinates.
    unsqueeze: if not None, axis inserted into every component before stacking.
    cat_dim: axis along which the components are stacked; None returns the
        tuple of components instead of a single array.
    homogeneous: if True, a third component filled with ones is appended.
    arange_kw: extra keyword arguments (e.g. dtype) forwarded to ``arange``;
        the dtype of the coordinates is whatever ``arange`` produces.
    """
    if device is None:
        # numpy
        arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
        expand = np.expand_dims
    else:
        # torch
        def arange(*a, **kw):
            return torch.arange(*a, device=device, **kw)

        def ones(*a):
            return torch.ones(*a, device=device)

        meshgrid, stack = torch.meshgrid, torch.stack
        expand = torch.unsqueeze

    tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]
    # np.meshgrid may return a list; normalise so the tuple concatenation
    # below works for both backends.
    grid = tuple(meshgrid(tw, th, indexing="xy"))
    if homogeneous:
        grid = grid + (ones((H, W)),)
    if unsqueeze is not None:
        # expand every component (including the homogeneous one) and do it
        # with a backend-appropriate function (ndarray has no .unsqueeze)
        grid = tuple(expand(g, unsqueeze) for g in grid)
    if cat_dim is not None:
        grid = stack(grid, cat_dim)
    return grid
def estimate_focal_knowing_depth(
    pts3d, pp, focal_mode="median", min_focal=0.0, max_focal=np.inf
):
    """Reprojection method, for when the absolute depth is known:
    1) estimate the camera focal using a robust estimator
    2) reproject points onto true rays, minimizing a certain error

    pts3d: tensor of shape (B, H, W, 3).
    pp: principal point(s), viewed as (-1, 1, 2) and subtracted from the pixel grid.
    focal_mode: "median" (median of per-pixel focal votes) or "weiszfeld"
        (closed-form init followed by 10 re-weighted least-squares steps).
    min_focal, max_focal: clipping bounds, expressed as multiples of
        ``max(H, W) / (2 * tan(60deg / 2))``.

    Returns one focal per batch element. Raises ``ValueError`` on an unknown
    ``focal_mode``.
    """
    B, H, W, THREE = pts3d.shape
    assert THREE == 3

    # pixel coordinates expressed relative to the principal point
    grid = xy_grid(W, H, device=pts3d.device).view(1, -1, 2)
    pixels = grid - pp.view(-1, 1, 2)  # B,HW,2
    pts3d = pts3d.flatten(1, 2)  # (B, HW, 3)

    if focal_mode == "median":
        with torch.no_grad():
            # every pixel votes for a focal along x and along y
            u, v = pixels.unbind(dim=-1)
            x, y, z = pts3d.unbind(dim=-1)
            votes_x = (u * z) / x
            votes_y = (v * z) / y

            # square pixels are assumed, so both axes vote for the same focal
            all_votes = torch.cat((votes_x.view(B, -1), votes_y.view(B, -1)), dim=-1)
            focal = torch.nanmedian(all_votes, dim=-1).values

    elif focal_mode == "weiszfeld":
        # we look for focal = argmin Sum | pixel - focal * (x,y)/z |
        # non-finite ratios are zeroed out
        xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(posinf=0, neginf=0)

        dot_xy_px = (xy_over_z * pixels).sum(dim=-1)
        dot_xy_xy = xy_over_z.square().sum(dim=-1)

        # l2 closed-form initialisation
        focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1)

        # iterative re-weighted least-squares
        for _ in range(10):
            # weight each pixel by the inverse of its current residual
            residual = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1)
            weights = residual.clip(min=1e-8).reciprocal()
            # refit the focal with the new weights
            focal = (weights * dot_xy_px).mean(dim=1) / (weights * dot_xy_xy).mean(dim=1)
    else:
        raise ValueError(f"bad {focal_mode=}")

    # reference focal: size / 1.1547005383792515
    focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2))
    return focal.clip(min=min_focal * focal_base, max=max_focal * focal_base)
def estimate_focal(pts3d_i, pp=None):
    """Estimate a single focal (Python float) for one (H, W, 3) point map.

    When ``pp`` is None the principal point defaults to the image centre
    ``(W / 2, H / 2)``. The estimation uses the "weiszfeld" mode of
    ``estimate_focal_knowing_depth``.
    """
    if pp is None:
        H, W, THREE = pts3d_i.shape
        assert THREE == 3
        pp = torch.tensor((W / 2, H / 2), device=pts3d_i.device)

    # add the batch dimension expected by the batched estimator
    batched_pts = pts3d_i.unsqueeze(0)
    batched_pp = pp.unsqueeze(0)
    focal = estimate_focal_knowing_depth(
        batched_pts, batched_pp, focal_mode="weiszfeld"
    )
    return float(focal.ravel())
def rigid_points_registration(pts1, pts2, conf):
    """Register ``pts1`` onto ``pts2`` with ``roma`` using ``conf`` as weights.

    Inputs are flattened to (N, 3) points and N weights; scaling is estimated.
    Returns ``(s, R, T)`` -- note the order differs from roma's ``(R, T, s)``.
    """
    src = pts1.reshape(-1, 3)
    dst = pts2.reshape(-1, 3)
    R, T, s = roma.rigid_points_registration(
        src, dst, weights=conf.ravel(), compute_scaling=True
    )
    return s, R, T  # return un-scaled (R, T)
def sRT_to_4x4(scale, R, T, device):
    """Assemble a 4x4 matrix from a scale, a 3x3 rotation and a translation.

    The upper-left 3x3 block is ``R * scale``; the translation (flattened) is
    written unscaled into the last column. The matrix is created on ``device``.
    """
    mat = torch.eye(4, device=device)
    mat[:3, 3] = T.ravel()  # translation is not scaled
    mat[:3, :3] = R * scale
    return mat
def geotrf(Trf, pts, ncol=None, norm=False):
    """Apply a geometric transformation to a set of points.

    Trf: numpy array or torch tensor with at least 2 dimensions, e.g. a 3x3
        or 4x4 projection matrix (typically a Homography); leading
        dimensions are treated as batch dimensions.
    pts: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3).
        It is converted to the same library (and, for torch, dtype) as Trf.
    ncol: int. number of columns of the result (2 or 3); defaults to the
        last dimension of pts.
    norm: float. if != 0, the result is divided by its last coordinate and
        then multiplied by norm, i.e. projected on the z=norm plane.

    Returns an array of shape (*pts.shape[:-1], ncol).
    Raises ValueError in the batched torch fast path when the last dimension
    of Trf is neither d nor d+1 (d being the point dimension).
    """
    assert Trf.ndim >= 2
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)
    # adapt shape if necessary
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]
    # optimized code
    # fast path: Trf is (B,*,*) and pts is (B,H,W,d), both torch tensors
    if (
        isinstance(Trf, torch.Tensor)
        and isinstance(pts, torch.Tensor)
        and Trf.ndim == 3
        and pts.ndim == 4
    ):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            # purely linear transform
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            # linear part (top-left d x d) plus the translation column
            pts = (
                torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
                + Trf[:, None, None, :d, d]
            )
        else:
            raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
    else:
        if Trf.ndim >= 3:
            # flatten all batch dimensions of Trf into a single one
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]
        if pts.shape[-1] + 1 == Trf.shape[-1]:
            # Trf has one more column than the points: the last row of the
            # transposed matrix is added as an offset
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            # fallback when the sizes match neither case above
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)
    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm
    # keep the first ncol coordinates and restore the leading shape of pts
    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res
def inv(mat):
    """Invert a torch or numpy matrix.

    Dispatches to ``torch.linalg.inv`` or ``np.linalg.inv`` depending on the
    type of ``mat``; raises ``ValueError`` for any other type.
    """
    if isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    if isinstance(mat, torch.Tensor):
        return torch.linalg.inv(mat)
    raise ValueError(f"bad matrix type = {type(mat)}")
@cache
def pixel_grid(H, W):
    """Return a cached (H, W, 2) float32 grid with ``grid[y, x] == (x, y)``.

    The result is memoised per (H, W), so the same array object is handed to
    every caller.
    """
    coords = np.mgrid[:W, :H]  # (2, W, H): x indices, then y indices
    return coords.T.astype(np.float32)
def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
    """Extract a camera pose (and focal) with RANSAC-PnP.

    pts3d: (H, W, 3) point map; msk: boolean mask selecting the points to use.
    focal: known focal, or None to try 21 geometrically spaced candidates
        between ``max(W, H) / 2`` and ``max(W, H) * 3``.
    pp: principal point; defaults to the image centre ``(W / 2, H / 2)``.
    niter_PnP: RANSAC iteration count passed to ``cv2.solvePnPRansac``.

    Returns ``(best_focal, cam2world)`` where ``cam2world`` is a 4x4 matrix on
    ``device``, or None when fewer than 4 points are selected or no candidate
    focal yields a successful solve with inliers.
    """
    # we need at least 4 points for PnP
    if msk.sum() < 4:
        return None

    pts3d, msk = map(to_numpy, (pts3d, msk))
    H, W, THREE = pts3d.shape
    assert THREE == 3
    pixels = pixel_grid(H, W)

    if focal is None:
        S = max(W, H)
        tentative_focals = np.geomspace(S / 2, S * 3, 21)
    else:
        tentative_focals = [focal]

    pp = (W / 2, H / 2) if pp is None else to_numpy(pp)

    # the 3D/2D correspondences are the same for every candidate focal
    object_pts = pts3d[msk]
    image_pts = pixels[msk]

    best = (0,)  # (inlier count, R, T, focal) once a solve succeeded
    for candidate in tentative_focals:
        K = np.float32([(candidate, 0, pp[0]), (0, candidate, pp[1]), (0, 0, 1)])
        success, R, T, inliers = cv2.solvePnPRansac(
            object_pts,
            image_pts,
            K,
            None,
            iterationsCount=niter_PnP,
            reprojectionError=5,
            flags=cv2.SOLVEPNP_SQPNP,
        )
        if not success:
            continue

        # keep the candidate supported by the most inliers
        score = len(inliers)
        if score > best[0]:
            best = score, R, T, candidate

    if not best[0]:
        return None

    _, R, T, best_focal = best
    R = cv2.Rodrigues(R)[0]  # world to cam
    R, T = map(torch.from_numpy, (R, T))
    world2cam = sRT_to_4x4(1, R, T, device)
    return best_focal, inv(world2cam)  # cam to world
def get_med_dist_between_poses(poses):
    """Return the median pairwise distance between the translation parts
    (``p[:3, 3]``) of the given poses."""
    from scipy.spatial.distance import pdist

    centers = [to_numpy(p[:3, 3]) for p in poses]
    return np.median(pdist(centers))
def align_multiple_poses(src_poses, target_poses):
    """Estimate the similarity (s, R, T) aligning ``src_poses`` to ``target_poses``.

    Both inputs must have shape (N, 4, 4). Every pose contributes two points:
    its translation, and its translation shifted by a small step along its
    third rotation column. The two point sets are registered with ``roma``
    (scaling enabled). Returns ``(s, R, T)``.
    """
    N = len(src_poses)
    assert src_poses.shape == target_poses.shape == (N, 4, 4)

    def center_and_z(poses):
        # step = 1% of the median distance between pose centres
        eps = get_med_dist_between_poses(poses) / 100
        centers = poses[:, :3, 3]
        z_axes = poses[:, :3, 2]
        return torch.cat((centers, centers + eps * z_axes))

    src_pts = center_and_z(src_poses)
    tgt_pts = center_and_z(target_poses)
    R, T, s = roma.rigid_points_registration(src_pts, tgt_pts, compute_scaling=True)
    return s, R, T
def cosine_schedule(t, lr_start, lr_end):
    """Cosine interpolation from ``lr_start`` (t=0) to ``lr_end`` (t=1).

    Raises ``AssertionError`` if ``t`` is outside [0, 1].
    """
    assert 0 <= t <= 1
    span = lr_start - lr_end
    cos_term = 1 + np.cos(t * np.pi)
    return lr_end + span * cos_term / 2
def linear_schedule(t, lr_start, lr_end):
    """Linear interpolation from ``lr_start`` (t=0) to ``lr_end`` (t=1).

    Raises ``AssertionError`` if ``t`` is outside [0, 1].
    """
    assert 0 <= t <= 1
    delta = lr_end - lr_start
    return lr_start + delta * t
def cycled_linear_schedule(t, lr_start, lr_end, num_cycles=2):
    """Repeat the linear schedule ``num_cycles`` times over t in [0, 1].

    The position inside the current cycle is the fractional part of
    ``t * num_cycles``; at exactly ``t == 1`` the end of the cycle is used so
    the schedule finishes on ``lr_end``. Raises ``AssertionError`` if ``t``
    is outside [0, 1].
    """
    assert 0 <= t <= 1
    phase = t * num_cycles
    fraction = phase - int(phase)
    # the very end of the schedule must land on the end of a cycle, not its start
    cycle_t = 1 if t == 1 else fraction
    return linear_schedule(cycle_t, lr_start, lr_end)
def adjust_learning_rate_by_lr(optimizer, lr):
    """Set the learning rate of every param group of ``optimizer`` in place.

    A group that defines ``"lr_scale"`` gets ``lr * lr_scale``; every other
    group gets ``lr`` unchanged.
    """
    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr