# Source: vmem/extern/CUT3R/cloud_opt/init_all.py
# Author: liguang0115 — commit 2df809d:
#   "Add initial project structure with core files, configurations, and sample images"
from functools import cache
import numpy as np
import scipy.sparse as sp
import torch
import cv2
import roma
from tqdm import tqdm
from cloud_opt.utils import *
def compute_edge_scores(edges, edge2conf_i, edge2conf_j):
    """Score every edge of the pairwise graph with ``edge_conf``.

    edges: iterable of ``(key, (i, j))`` items, where ``key`` is the ``'i_j'``
        string used to index the two confidence mappings and ``(i, j)`` are
        the image indices of that edge.
    edge2conf_i, edge2conf_j: mappings from an ``'i_j'`` key to the
        confidences of the first / second prediction of that edge.

    Returns a dict mapping ``(i, j)`` to
    ``edge_conf(edge2conf_i[key], edge2conf_j[key])``. A later item with the
    same ``(i, j)`` overwrites an earlier one.
    """
    scores = {}
    for key, (i, j) in edges:
        scores[i, j] = edge_conf(edge2conf_i[key], edge2conf_j[key])
    return scores
def dict_to_sparse_graph(dic):
    """Build a square sparse adjacency matrix from an ``{(i, j): weight}`` dict.

    The side length is one more than the largest node index appearing in any
    key, so every referenced node fits. Pairs absent from ``dic`` are zero.
    An empty ``dic`` raises ``ValueError`` (from ``max`` on an empty sequence).
    """
    size = 1 + max(idx for pair in dic for idx in pair)
    graph = sp.dok_array((size, size))
    for (row, col), weight in dic.items():
        graph[row, col] = weight
    return graph
@torch.no_grad()
def init_minimum_spanning_tree(self, **kw):
    """Initialise image-wise and pairwise camera poses from pairwise predictions.

    Builds a global reconstruction with ``minimum_spanning_tree`` from the edge
    predictions and confidences held by ``self``. Extra keyword arguments
    (e.g. ``niter_PnP``) are forwarded to it unchanged. The resulting points,
    focals and poses are then loaded into ``self`` via ``init_from_pts3d``,
    whose return value is passed through. Runs under ``torch.no_grad()``.
    """
    target_device = self.device
    spanning_tree_result = minimum_spanning_tree(
        self.imshapes,
        self.edges,
        self.edge2pts_i,
        self.edge2pts_j,
        self.edge2conf_i,
        self.edge2conf_j,
        self.im_conf,
        self.min_conf_thr,
        target_device,
        has_im_poses=self.has_im_poses,
        verbose=self.verbose,
        **kw,
    )
    # the spanning-tree edge list is not needed here
    pts3d, _tree_edges, im_focals, im_poses = spanning_tree_result
    return init_from_pts3d(self, pts3d, im_focals, im_poses)
def minimum_spanning_tree(
    imshapes,
    edges,
    edge2pred_i,
    edge2pred_j,
    edge2conf_i,
    edge2conf_j,
    im_conf,
    min_conf_thr,
    device,
    has_im_poses=True,
    niter_PnP=10,
    verbose=True,
    save_score_path=None,
):
    """Chain pairwise predictions into one global frame along a spanning tree.

    Edges are scored with ``compute_edge_scores`` and the scores are negated,
    so scipy's *minimum* spanning tree keeps the highest-scoring edges. The
    strongest tree edge seeds the global frame: its two predictions are copied
    as-is. The remaining tree edges are processed from strongest to weakest.
    Whenever one endpoint is already placed, that edge's prediction for the
    placed endpoint is rigidly registered onto the placed points, and the same
    transform is applied to the other endpoint's prediction. Edges with no
    placed endpoint yet are retried later.

    Args:
        imshapes: per-image shapes; only its length (number of images) is used.
        edges: iterable of ``(i, j)`` image-index pairs, converted with ``i_j_ij``.
        edge2pred_i / edge2pred_j: map an ``'i_j'`` key to the first / second
            prediction of that edge.
        edge2conf_i / edge2conf_j: map an ``'i_j'`` key to the matching
            confidences (used for edge scores and registration weights).
        im_conf: per-image confidences; ``im_conf[i] > min_conf_thr`` is the
            point mask given to ``fast_pnp``.
        min_conf_thr: confidence threshold for that PnP mask.
        device: torch device for the pose tensors created here.
        has_im_poses: when False, no focals/poses are computed and both are
            returned as None.
        niter_PnP: forwarded to ``fast_pnp``.
        verbose: print each edge as it is attached.
        save_score_path: accepted for API compatibility; currently unused.

    Returns:
        ``(pts3d, msp_edges, im_focals, im_poses)``:
        - ``pts3d``: list with one entry per image (None if never reached).
        - ``msp_edges``: tree edges ``(i, j)`` in the order they were attached.
        - ``im_focals``: list of per-image focals (entries may stay None), or None.
        - ``im_poses``: ``torch.stack`` of per-image 4x4 poses (identity when
          unknown), or None.

    Raises:
        ValueError: if some spanning-tree edges can never be attached to the
            initial component (the pairwise graph is disconnected). Previously
            this case looped forever.
    """
    n_imgs = len(imshapes)
    edge_scores = compute_edge_scores(map(i_j_ij, edges), edge2conf_i, edge2conf_j)
    # Negate so that scipy's *minimum* spanning tree keeps the best edges.
    sparse_graph = -dict_to_sparse_graph(edge_scores)
    msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo()

    # temp variable to store 3d points, one slot per image
    pts3d = [None] * n_imgs
    # ascending by score, so pop() returns the strongest remaining edge
    todo = sorted(zip(-msp.data, msp.row, msp.col))
    im_poses = [None] * n_imgs
    im_focals = [None] * n_imgs

    # init with strongest edge; its predictions define the global frame
    score, i, j = todo.pop()
    if verbose:
        print(f" init edge ({i}*,{j}*) {score=}")
    i_j = edge_str(i, j)
    pts3d[i] = edge2pred_i[i_j].clone()
    pts3d[j] = edge2pred_j[i_j].clone()
    done = {i, j}
    if has_im_poses:
        im_poses[i] = torch.eye(4, device=device)
        im_focals[i] = estimate_focal(edge2pred_i[i_j])

    # set initial pointcloud based on pairwise graph
    msp_edges = [(i, j)]
    # Consecutive deferrals since the last edge was attached. Once every
    # remaining edge has been deferred in a row, none of them can ever attach.
    n_stalled = 0
    while todo:
        # each time, predict the next one
        score, i, j = todo.pop()
        # Key of *this* edge. It must be computed before the focal estimate
        # below; otherwise that estimate reads the previous edge's prediction.
        i_j = edge_str(i, j)

        # focals are discarded when has_im_poses is False, so skip the work
        if has_im_poses and im_focals[i] is None:
            im_focals[i] = estimate_focal(edge2pred_i[i_j])

        if i in done:
            if verbose:
                print(f" init edge ({i},{j}*) {score=}")
            assert j not in done  # a tree edge never closes a cycle
            # align pred[i] with pts3d[i], and then set j accordingly
            s, R, T = rigid_points_registration(
                edge2pred_i[i_j], pts3d[i], conf=edge2conf_i[i_j]
            )
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[j] = geotrf(trf, edge2pred_j[i_j])
            done.add(j)
            msp_edges.append((i, j))
            n_stalled = 0
            if has_im_poses and im_poses[i] is None:
                # unscaled rotation/translation of the registration → pose of i
                im_poses[i] = sRT_to_4x4(1, R, T, device)
        elif j in done:
            if verbose:
                print(f" init edge ({i}*,{j}) {score=}")
            assert i not in done  # a tree edge never closes a cycle
            # align pred[j] with pts3d[j], and then set i accordingly
            s, R, T = rigid_points_registration(
                edge2pred_j[i_j], pts3d[j], conf=edge2conf_j[i_j]
            )
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[i] = geotrf(trf, edge2pred_i[i_j])
            done.add(i)
            msp_edges.append((i, j))
            n_stalled = 0
            if has_im_poses and im_poses[i] is None:
                # assumes both predictions of edge i_j share camera i's frame,
                # so this transform is also i's pose — TODO confirm convention
                im_poses[i] = sRT_to_4x4(1, R, T, device)
        else:
            # let's try again later
            todo.insert(0, (score, i, j))
            n_stalled += 1
            if n_stalled >= len(todo):
                raise ValueError(
                    "minimum_spanning_tree: the pairwise graph is not connected; "
                    f"{len(todo)} spanning-forest edge(s) cannot be attached "
                    "to the initial component"
                )

    if has_im_poses:
        # complete all missing informations
        pair_scores = list(
            sparse_graph.values()
        )  # already negative scores: less is best
        edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[
            np.argsort(pair_scores)
        ]
        for i, j in edges_from_best_to_worse.tolist():
            if im_focals[i] is None:
                im_focals[i] = estimate_focal(edge2pred_i[edge_str(i, j)])
        for i in range(n_imgs):
            if im_poses[i] is None:
                msk = im_conf[i] > min_conf_thr
                res = fast_pnp(
                    pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP
                )
                if res:
                    im_focals[i], im_poses[i] = res
            if im_poses[i] is None:
                # PnP failed or was skipped: fall back to identity
                im_poses[i] = torch.eye(4, device=device)
        im_poses = torch.stack(im_poses)
    else:
        im_poses = im_focals = None
    return pts3d, msp_edges, im_focals, im_poses
def init_from_pts3d(self, pts3d, im_focals, im_poses):
    """Seed the optimizer ``self`` from a global point cloud and camera estimates.

    Steps:
      1. If more than one image pose is known (``self.get_known_poses()``),
         rigidly align ``im_poses`` and every point map onto them.
      2. For every edge, register its first prediction onto ``pts3d[i]`` and
         store the result as that edge's pairwise pose.
      3. Apply ``self.get_pw_norm_scale_factor()`` to pose translations and
         to the point maps.
      4. If ``self.has_im_poses``, set per-image depth maps, poses and focals.
         With ``self.shared_focal``, a single averaged focal is stored at index 0.

    Args:
        self: the optimizer/scene object being initialised (passed explicitly).
        pts3d: list of per-image point tensors in a common frame; modified in place.
        im_focals: per-image focal estimates (entries may be None), or None.
        im_poses: stacked camera-to-world 4x4 matrices indexed ``[img]``, or None.
            ``minimum_spanning_tree`` returns None when ``has_im_poses`` is False.

    Raises:
        NotImplementedError: if exactly one image pose is known.
    """
    # init poses
    nkp, known_poses_msk, known_poses = self.get_known_poses()
    if nkp == 1:
        raise NotImplementedError(
            "Would be simpler to just align everything afterwards on the single known pose"
        )
    elif nkp > 1:
        # global rigid SE3 alignment
        s, R, T = align_multiple_poses(
            im_poses[known_poses_msk], known_poses[known_poses_msk]
        )
        trf = sRT_to_4x4(s, R, T, device=known_poses.device)
        # rotate everything
        im_poses = trf @ im_poses
        im_poses[:, :3, :3] /= s  # undo scaling on the rotation part
        for img_pts3d in pts3d:
            img_pts3d[:] = geotrf(trf, img_pts3d)
    # else: no known poses, keep the estimated frame as-is

    # set all pairwise poses
    for e, (i, j) in enumerate(self.edges):
        i_j = edge_str(i, j)
        # compute transform that goes from cam to world
        s, R, T = rigid_points_registration(
            self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j]
        )
        self._set_pose(self.pw_poses, e, R, T, scale=s)

    # take into account the scale normalization
    s_factor = self.get_pw_norm_scale_factor()
    if im_poses is not None:  # None when image poses were not estimated
        im_poses[:, :3, 3] *= s_factor  # apply downscaling factor
    for img_pts3d in pts3d:
        img_pts3d *= s_factor

    # init all image poses
    if self.has_im_poses:
        for i in range(self.n_imgs):
            cam2world = im_poses[i]
            # depth = z of the points expressed in camera i's frame
            depth = geotrf(inv(cam2world), pts3d[i])[..., 2]
            self._set_depthmap(i, depth)
            self._set_pose(self.im_poses, i, cam2world)
            if im_focals[i] is not None and not self.shared_focal:
                self._set_focal(i, im_focals[i])
        if self.shared_focal:
            # Average only the focals that were actually estimated; a None
            # entry would make sum() raise TypeError. When every focal is
            # present this matches the previous sum / n_imgs.
            known_focals = [f for f in im_focals if f is not None]
            if known_focals:
                self._set_focal(0, sum(known_focals) / len(known_focals))

    if self.n_imgs > 2:
        self._set_init_depthmap()

    if self.verbose:
        with torch.no_grad():
            print(" init loss =", float(self()))