# Scraped page header (Hugging Face Spaces status): "Spaces: Running on L4"
import torch.nn as nn | |
import torch | |
import roma | |
import numpy as np | |
import cv2 | |
from functools import cache | |
def todevice(batch, device, callback=None, non_blocking=False):
    """Recursively move tensors/arrays to another device, or convert to numpy.

    Args:
        batch: a tensor, a numpy array, ``None``, any other object, or an
            arbitrarily nested dict / list / tuple of those.
        device: a torch device (or device string), or the string ``"numpy"``
            to turn every tensor into a numpy array.
        callback: optional function applied to every (sub-)element before it
            is processed; its return value replaces that element.
        non_blocking: forwarded to ``Tensor.to`` for every tensor transfer.

    Returns:
        An object with the same nesting as ``batch``. Dicts come back as
        plain dicts; lists and tuples (including namedtuples) keep their
        type. Leaves that are neither tensors nor numpy arrays are returned
        unchanged.
    """
    if callback:
        batch = callback(batch)

    def recurse(item):
        # Forward the options so nested elements are treated like the root
        # (previously `callback` and `non_blocking` were lost on recursion).
        return todevice(item, device, callback=callback, non_blocking=non_blocking)

    if isinstance(batch, dict):
        return {k: recurse(v) for k, v in batch.items()}

    if isinstance(batch, (tuple, list)):
        items = [recurse(x) for x in batch]
        if hasattr(batch, "_fields"):
            # namedtuple constructors take positional fields, not an iterable
            return type(batch)(*items)
        return type(batch)(items)

    x = batch
    if device == "numpy":
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x


to_device = todevice  # alias
def to_numpy(x):
    """Convert every tensor found in ``x`` (possibly nested) to a numpy array."""
    return todevice(x, device="numpy")
def to_cpu(x):
    """Move every tensor/array found in ``x`` (possibly nested) to the ``"cpu"`` device."""
    return todevice(x, device="cpu")
def to_cuda(x):
    """Move every tensor/array found in ``x`` (possibly nested) to the ``"cuda"`` device."""
    return todevice(x, device="cuda")
def signed_log1p(x):
    """Return ``sign(x) * log(1 + |x|)`` element-wise for a tensor ``x``."""
    magnitude = x.abs().log1p()
    return x.sign() * magnitude
def l2_dist(a, b, weight):
    """Weighted squared Euclidean distance between ``a`` and ``b`` over the last axis."""
    diff = a - b
    squared = torch.square(diff).sum(dim=-1)
    return squared * weight
def l1_dist(a, b, weight):
    """Weighted (un-squared) distance between ``a`` and ``b`` over the last axis.

    Note: despite the name, this uses ``Tensor.norm`` with its default order,
    i.e. the Euclidean norm of the difference, not a sum of absolute values.
    """
    diff = a - b
    length = diff.norm(dim=-1)
    return length * weight
# Registry mapping a distance name to its implementation.
ALL_DISTS = {"l1": l1_dist, "l2": l2_dist}
def _check_edges(edges): | |
indices = sorted({i for edge in edges for i in edge}) | |
assert indices == list(range(len(indices))), "bad pair indices: missing values " | |
return len(indices) | |
def NoGradParamDict(x):
    """Wrap the dict ``x`` in an ``nn.ParameterDict`` with gradients disabled."""
    assert isinstance(x, dict)
    frozen = nn.ParameterDict(x)
    # requires_grad_ works in place and returns the module itself
    return frozen.requires_grad_(False)
def edge_str(i, j):
    """Return the string key ``"<i>_<j>"`` identifying the edge (i, j)."""
    return "{}_{}".format(i, j)
def i_j_ij(ij):
    """Given an edge ``ij = (i, j)``, return ``(edge_str(i, j), ij)``."""
    key = edge_str(*ij)
    return key, ij
def edge_conf(conf_i, conf_j):
    """Score an edge as the product of the two mean confidences (Python float)."""
    mean_i = conf_i.mean()
    mean_j = conf_j.mean()
    return float(mean_i * mean_j)
def get_imshapes(edges, pred_i, pred_j):
    """Collect the (first two dims of the) prediction shape of every image.

    For the e-th edge (i, j), image i's shape is read from
    ``pred_i[e]["pts3d_is_self_view"]`` and image j's shape from
    ``pred_j[e]["pts3d_in_other_view"]``.
    NOTE(review): the key "pts3d_is_self_view" may be a misspelling of
    "pts3d_in_self_view" -- confirm against whatever produces ``pred_i``.

    Returns a list indexed by image id; entries never touched by an edge
    stay ``None``. Raises AssertionError when an image is seen with two
    different shapes.
    """
    n_imgs = 1 + max(max(edge) for edge in edges)
    imshapes = [None for _ in range(n_imgs)]

    for e, (i, j) in enumerate(edges):
        view_i = pred_i[e]["pts3d_is_self_view"]
        view_j = pred_j[e]["pts3d_in_other_view"]
        shape_i = tuple(view_i.shape[0:2])
        shape_j = tuple(view_j.shape[0:2])

        # validate both against what is already recorded, then record
        if imshapes[i]:
            assert imshapes[i] == shape_i, f"incorrect shape for image {i}"
        if imshapes[j]:
            assert imshapes[j] == shape_j, f"incorrect shape for image {j}"
        imshapes[i], imshapes[j] = shape_i, shape_j

    return imshapes
def get_conf_trf(mode):
    """Return the confidence transform selected by ``mode``.

    Supported modes: ``"log"`` (x.log()), ``"sqrt"`` (x.sqrt()),
    ``"m1"`` (x - 1), and ``"id"`` / ``"none"`` (identity).
    Raises ValueError for any other mode.
    """
    if mode == "log":
        return lambda x: x.log()
    if mode == "sqrt":
        return lambda x: x.sqrt()
    if mode == "m1":
        return lambda x: x - 1
    if mode in ("id", "none"):
        return lambda x: x
    raise ValueError(f"bad mode for {mode=}")
def _compute_img_conf(imshapes, device, edges, edge2conf_i, edge2conf_j):
    """Per-image confidence: element-wise max over all edges touching the image.

    Starts from zeros of each shape in ``imshapes`` and, for every edge
    (i, j), takes the maximum with ``edge2conf_i[edge_str(i, j)]`` for image
    i and ``edge2conf_j[edge_str(i, j)]`` for image j. Returns an
    ``nn.ParameterList``.
    """
    zeros = [torch.zeros(hw, device=device) for hw in imshapes]
    im_conf = nn.ParameterList(zeros)
    for i, j in edges:
        key = edge_str(i, j)
        im_conf[i] = torch.maximum(im_conf[i], edge2conf_i[key])
        im_conf[j] = torch.maximum(im_conf[j], edge2conf_j[key])
    return im_conf
def xy_grid(
    W,
    H,
    device=None,
    origin=(0, 0),
    unsqueeze=None,
    cat_dim=-1,
    homogeneous=False,
    **arange_kw,
):
    """Build the pixel-coordinate grid of a W x H image.

    With the defaults the result is an (H, W, 2) array such that
        output[j, i, 0] = i + origin[0]
        output[j, i, 1] = j + origin[1]

    Args:
        W, H: grid width and height.
        device: ``None`` builds numpy arrays; any other value builds torch
            tensors on that device.
        origin: (x0, y0) offset added to the coordinates.
        unsqueeze: if not None, a singleton axis is inserted at this
            position in every component (works for numpy and torch).
        cat_dim: axis along which the components are stacked; ``None``
            returns the tuple of components instead.
        homogeneous: if True, an (H, W) plane of ones is appended as an
            extra component.
        **arange_kw: forwarded to ``arange`` (e.g. ``dtype``). The coordinate
            dtype is whatever ``arange`` produces for these arguments.
    """
    if device is None:
        # numpy backend
        arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
        add_axis = np.expand_dims
    else:
        # torch backend: bind the target device once

        def arange(*a, **kw):
            return torch.arange(*a, device=device, **kw)

        def ones(*a):
            return torch.ones(*a, device=device)

        meshgrid, stack = torch.meshgrid, torch.stack
        add_axis = torch.unsqueeze

    tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]
    # np.meshgrid may return a list and torch.meshgrid a tuple: normalise so
    # that appending the homogeneous plane works for both backends.
    grid = tuple(meshgrid(tw, th, indexing="xy"))
    if homogeneous:
        grid = grid + (ones((H, W)),)
    if unsqueeze is not None:
        # apply to every component so the homogeneous plane is not dropped
        grid = tuple(add_axis(g, unsqueeze) for g in grid)
    if cat_dim is not None:
        grid = stack(grid, cat_dim)
    return grid
def estimate_focal_knowing_depth(
    pts3d, pp, focal_mode="median", min_focal=0.0, max_focal=np.inf
):
    """Estimate one focal length per batch element from a 3-D point map.

    Args:
        pts3d: tensor of shape (B, H, W, 3).
        pp: principal point(s); reshaped to (-1, 1, 2) and subtracted from
            the pixel grid.
        focal_mode: ``"median"`` takes the (nan-)median of the per-pixel
            votes ``u*z/x`` and ``v*z/y``; ``"weiszfeld"`` starts from the
            least-squares closed form and runs 10 re-weighting iterations.
        min_focal, max_focal: clipping bounds, expressed as multiples of the
            focal length giving a 60 degree field of view over max(H, W).

    Returns:
        The clipped focal tensor.

    Raises:
        ValueError: for an unknown ``focal_mode``.
    """
    B, H, W, THREE = pts3d.shape
    assert THREE == 3

    # pixel coordinates expressed relative to the principal point: B,HW,2
    grid = xy_grid(W, H, device=pts3d.device).view(1, -1, 2)
    pixels = grid - pp.view(-1, 1, 2)
    pts3d = pts3d.flatten(1, 2)  # (B, HW, 3)

    if focal_mode == "median":
        with torch.no_grad():
            # every pixel votes for a focal, independently along x and y
            u, v = pixels.unbind(dim=-1)
            x, y, z = pts3d.unbind(dim=-1)
            fx_votes = (u * z) / x
            fy_votes = (v * z) / y

            # square pixels are assumed, so x and y votes share one focal
            f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1)
            focal = torch.nanmedian(f_votes, dim=-1).values

    elif focal_mode == "weiszfeld":
        # we look for focal = argmin Sum | pixel - focal * (x,y)/z |
        # and initialise it with the l2 closed form
        xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(
            posinf=0, neginf=0
        )  # homogeneous (x,y,1)

        dot_xy_px = (xy_over_z * pixels).sum(dim=-1)
        dot_xy_xy = xy_over_z.square().sum(dim=-1)
        focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1)

        # iterative re-weighted least-squares
        for _ in range(10):
            # weights are the inverse of the current residual distance
            residual = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1)
            w = residual.clip(min=1e-8).reciprocal()
            # re-solve for the scale with the new weights
            focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1)
    else:
        raise ValueError(f"bad {focal_mode=}")

    # focal for a 60 degree FOV, i.e. size / 1.1547005383792515
    focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2))
    lower, upper = min_focal * focal_base, max_focal * focal_base
    return focal.clip(min=lower, max=upper)
def estimate_focal(pts3d_i, pp=None):
    """Estimate a single focal (Python float) for one (H, W, 3) point map.

    When ``pp`` is None the principal point defaults to the image centre
    ``(W / 2, H / 2)``. The estimate uses the ``"weiszfeld"`` mode of
    ``estimate_focal_knowing_depth`` on a batch of size one.
    """
    if pp is None:
        H, W, THREE = pts3d_i.shape
        assert THREE == 3
        pp = torch.tensor((W / 2, H / 2), device=pts3d_i.device)

    batched_pts = pts3d_i.unsqueeze(0)
    batched_pp = pp.unsqueeze(0)
    focal = estimate_focal_knowing_depth(
        batched_pts, batched_pp, focal_mode="weiszfeld"
    )
    return float(focal.ravel())
def rigid_points_registration(pts1, pts2, conf):
    """Weighted similarity registration of two point sets via ``roma``.

    Both point sets are flattened to (N, 3) and ``conf`` is flattened into
    per-point weights; scaling is estimated as well.

    Returns:
        ``(s, R, T)`` -- note the order differs from roma's ``(R, T, s)``.
        R and T are returned as given by roma, i.e. not multiplied by ``s``.
    """
    src = pts1.reshape(-1, 3)
    dst = pts2.reshape(-1, 3)
    weights = conf.ravel()
    R, T, s = roma.rigid_points_registration(
        src, dst, weights=weights, compute_scaling=True
    )
    return s, R, T  # return un-scaled (R, T)
def sRT_to_4x4(scale, R, T, device):
    """Assemble a 4x4 transform from a scale, a 3x3 rotation and a translation.

    The upper-left 3x3 block is ``R * scale``; the translation column is
    ``T`` flattened, used as-is (it is not scaled).
    """
    mat = torch.eye(4, device=device)
    scaled_rot = R * scale
    mat[:3, :3] = scaled_rot
    mat[:3, 3] = T.ravel()
    return mat
def geotrf(Trf, pts, ncol=None, norm=False):
    """Apply a geometric transformation to a set of points.

    Args:
        Trf: numpy array or torch tensor with at least 2 dims; the last two
            dims hold the matrix (typically 3x3 or 4x4). Leading dims, if
            any, are batch dims.
        pts: points with coordinates on the last axis. Converted to the same
            array family as ``Trf``. If the matrix is one larger than the
            point dimension, the last matrix column is added as a
            translation.
        ncol: number of coordinates kept in the result; defaults to the
            input point dimension.
        norm: if truthy, the transformed points are divided by their last
            coordinate and, when ``norm != 1``, multiplied by ``norm``.

    Returns:
        Transformed points of shape ``pts.shape[:-1] + (ncol,)``.
    """
    assert Trf.ndim >= 2
    trf_is_tensor = isinstance(Trf, torch.Tensor)
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif trf_is_tensor:
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # remember the leading shape so the result can be restored at the end
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    batched_image_case = (
        trf_is_tensor
        and isinstance(pts, torch.Tensor)
        and Trf.ndim == 3
        and pts.ndim == 4
    )
    if batched_image_case:
        # fast path: Trf is (B,d,d) or (B,d+1,d+1) and pts is (B,H,W,d)
        d = pts.shape[3]
        mat_size = Trf.shape[-1]
        if mat_size == d:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif mat_size == d + 1:
            linear = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
            pts = linear + Trf[:, None, None, :d, d]
        else:
            raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
    else:
        if Trf.ndim >= 3:
            n_batch_dims = Trf.ndim - 2
            assert (
                Trf.shape[:n_batch_dims] == pts.shape[:n_batch_dims]
            ), "batch size does not match"
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        pt_dim = pts.shape[-1]
        if pt_dim + 1 == Trf.shape[-1]:
            # affine / homogeneous case: linear part plus translation row
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pt_dim == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        # out-of-place division on purpose (in-place `/=` misbehaves in torch)
        pts = pts / pts[..., -1:]
        if norm != 1:
            pts *= norm

    return pts[..., :ncol].reshape(*output_reshape, ncol)
def inv(mat):
    """Invert a torch or numpy matrix.

    Raises ValueError when ``mat`` is neither a torch tensor nor a numpy array.
    """
    backends = (
        (torch.Tensor, torch.linalg.inv),
        (np.ndarray, np.linalg.inv),
    )
    for kind, invert in backends:
        if isinstance(mat, kind):
            return invert(mat)
    raise ValueError(f"bad matrix type = {type(mat)}")
def pixel_grid(H, W):
    """Return an (H, W, 2) float32 array with ``grid[j, i] = (i, j)``."""
    coords = np.mgrid[0:W, 0:H]  # (2, W, H)
    return coords.transpose(2, 1, 0).astype(np.float32)
def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
    """Recover a camera pose (and focal) from a point map with RANSAC-PnP.

    Args:
        pts3d: (H, W, 3) point map.
        focal: known focal, or None to try 21 geometrically spaced values
            between ``max(W, H) / 2`` and ``max(W, H) * 3``.
        msk: mask selecting the points used; fewer than 4 selected points
            makes the function return None.
        device: device of the returned pose matrix.
        pp: principal point; defaults to the image centre ``(W / 2, H / 2)``.
        niter_PnP: RANSAC iteration count passed to OpenCV.

    Returns:
        ``(best_focal, cam_to_world)`` for the candidate with the most
        inliers, or None when no candidate succeeded.
    """
    # extract camera poses and focals with RANSAC-PnP
    if msk.sum() < 4:
        return None  # we need at least 4 points for PnP

    pts3d, msk = map(to_numpy, (pts3d, msk))
    H, W, THREE = pts3d.shape
    assert THREE == 3
    pixels = pixel_grid(H, W)

    if focal is None:
        S = max(W, H)
        tentative_focals = np.geomspace(S / 2, S * 3, 21)
    else:
        tentative_focals = [focal]

    pp = (W / 2, H / 2) if pp is None else to_numpy(pp)

    # best = (inlier count, rotation vector, translation, focal)
    best = (0,)
    for candidate in tentative_focals:
        K = np.float32(
            [(candidate, 0, pp[0]), (0, candidate, pp[1]), (0, 0, 1)]
        )
        success, rvec, tvec, inliers = cv2.solvePnPRansac(
            pts3d[msk],
            pixels[msk],
            K,
            None,
            iterationsCount=niter_PnP,
            reprojectionError=5,
            flags=cv2.SOLVEPNP_SQPNP,
        )
        if success:
            score = len(inliers)
            if score > best[0]:
                best = score, rvec, tvec, candidate

    if not best[0]:
        return None

    _, R, T, best_focal = best
    R = cv2.Rodrigues(R)[0]  # world to cam
    R, T = map(torch.from_numpy, (R, T))
    world_to_cam = sRT_to_4x4(1, R, T, device)
    return best_focal, inv(world_to_cam)  # cam to world
def get_med_dist_between_poses(poses):
    """Median pairwise distance between the translation parts of ``poses``.

    The translation of each pose is read from ``p[:3, 3]``.
    """
    from scipy.spatial.distance import pdist

    centers = [to_numpy(p[:3, 3]) for p in poses]
    pairwise = pdist(centers)
    return np.median(pairwise)
def align_multiple_poses(src_poses, target_poses):
    """Find the similarity transform aligning ``src_poses`` to ``target_poses``.

    Both inputs must have shape (N, 4, 4). Each pose contributes two points:
    its translation, and that translation moved a small step along the third
    column of its rotation block (step = median inter-pose distance / 100).
    The two point sets are registered with ``roma``.

    Returns:
        ``(s, R, T)``: scale, rotation and translation from roma.
    """
    N = len(src_poses)
    assert src_poses.shape == target_poses.shape == (N, 4, 4)

    def center_and_z(poses):
        # small offset relative to the scene scale
        eps = get_med_dist_between_poses(poses) / 100
        centers = poses[:, :3, 3]
        z_axes = poses[:, :3, 2]
        return torch.cat((centers, centers + eps * z_axes))

    src_pts = center_and_z(src_poses)
    tgt_pts = center_and_z(target_poses)
    R, T, s = roma.rigid_points_registration(src_pts, tgt_pts, compute_scaling=True)
    return s, R, T
def cosine_schedule(t, lr_start, lr_end):
    """Cosine interpolation from ``lr_start`` (t=0) to ``lr_end`` (t=1).

    ``t`` must lie in [0, 1] (checked with an assert).
    """
    assert 0 <= t <= 1
    amplitude = lr_start - lr_end
    cos_term = 1 + np.cos(t * np.pi)
    return lr_end + amplitude * cos_term / 2
def linear_schedule(t, lr_start, lr_end):
    """Linear interpolation from ``lr_start`` (t=0) to ``lr_end`` (t=1).

    ``t`` must lie in [0, 1] (checked with an assert).
    """
    assert 0 <= t <= 1
    delta = lr_end - lr_start
    return lr_start + delta * t
def cycled_linear_schedule(t, lr_start, lr_end, num_cycles=2):
    """Linear schedule restarted ``num_cycles`` times over t in [0, 1].

    The position inside the current cycle is the fractional part of
    ``t * num_cycles``; at exactly ``t == 1`` the position is forced to 1 so
    that the schedule ends on ``lr_end`` rather than restarting.
    """
    assert 0 <= t <= 1
    scaled = t * num_cycles
    phase = 1 if t == 1 else scaled - int(scaled)
    return linear_schedule(phase, lr_start, lr_end)
def adjust_learning_rate_by_lr(optimizer, lr):
    """Set the learning rate of every param group of ``optimizer`` to ``lr``.

    Groups that carry an ``"lr_scale"`` entry get ``lr * lr_scale`` instead.
    The groups are modified in place; nothing is returned.
    """
    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr