# NOTE(review): removed a non-Python extraction artifact ("Spaces: Running on L4",
# repeated twice) that preceded the imports and would not parse.
from copy import deepcopy | |
import cv2 | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import roma | |
from copy import deepcopy | |
import tqdm | |
import os | |
import matplotlib.pyplot as plt | |
from cloud_opt.utils import * | |
from cloud_opt.utils import _check_edges, _compute_img_conf | |
import cloud_opt.init_all as init_fun | |
class BaseOptimizer(nn.Module):
    """Optimize a global scene, given a graph-organized observations.

    Graph node: images
    Graph edges: observations = (pred1, pred2), pred2 is in pred1's coordinate
    """

    def __init__(self, *args, **kwargs):
        # Intentionally a no-op: concrete subclasses are expected to call
        # _init_from_views(), which performs nn.Module.__init__ via super().
        pass

    def _init_from_views(
        self,
        view1s,
        view2s,
        pred1s,
        pred2s,  # whatever predictions, they should be organized into pairwise for graph optimization
        dist="l1",
        conf="log",
        min_conf_thr=3,
        thr_for_init_conf=False,
        base_scale=0.5,
        allow_pw_adaptors=False,
        pw_break=20,
        rand_pose=torch.randn,
        empty_cache=False,
        verbose=True,
    ):
        """Populate graph state (edges, per-edge points/confidences, pose params).

        Args:
            view1s, view2s: per-edge view dicts, each carrying an integer "idx".
            pred1s: per-edge dicts with keys "pts3d_is_self_view" and "conf_self".
            pred2s: per-edge dicts with keys "pts3d_in_other_view" and "conf".
            dist: name of the pairwise distance, key into ALL_DISTS.
            conf: confidence-transform name, passed to get_conf_trf.
            min_conf_thr: confidence threshold kept for later filtering.
            thr_for_init_conf: whether the threshold also applies at init time.
            base_scale: target scale for pairwise-pose scale normalization.
            allow_pw_adaptors: if True, the xy/z adaptors become trainable.
            pw_break: scale-break constant for pairwise poses.
            rand_pose: factory used to initialize the pairwise pose parameters.
            empty_cache: if True, the alignment loop frees CUDA cache each step.
            verbose: enables progress printing in the alignment loop.
        """
        super().__init__()
        self.edges = [
            (int(view1["idx"]), int(view2["idx"]))
            for view1, view2 in zip(view1s, view2s)
        ]
        self.dist = ALL_DISTS[dist]
        # BUGFIX: both parameters were accepted but never stored, while
        # global_alignment_loop reads net.verbose and global_alignment_iter
        # reads net.empty_cache.
        self.verbose = verbose
        self.empty_cache = empty_cache
        self.n_imgs = _check_edges(self.edges)
        self.edge2pts_i = NoGradParamDict(
            {ij: pred1s[n]["pts3d_is_self_view"] for n, ij in enumerate(self.str_edges)}
        )  # ij: the name of the edge
        self.edge2pts_j = NoGradParamDict(
            {
                ij: pred2s[n]["pts3d_in_other_view"]
                for n, ij in enumerate(self.str_edges)
            }
        )
        self.edge2conf_i = NoGradParamDict(
            {ij: pred1s[n]["conf_self"] for n, ij in enumerate(self.str_edges)}
        )
        self.edge2conf_j = NoGradParamDict(
            {ij: pred2s[n]["conf"] for n, ij in enumerate(self.str_edges)}
        )
        self.imshapes = get_imshapes(self.edges, pred1s, pred2s)
        self.min_conf_thr = min_conf_thr
        self.thr_for_init_conf = thr_for_init_conf
        self.conf_trf = get_conf_trf(conf)
        # NOTE(review): assumes `self.device` is provided elsewhere (subclass
        # attribute or property) — TODO confirm; nn.Module has no such attribute.
        self.im_conf = _compute_img_conf(
            self.imshapes, self.device, self.edges, self.edge2conf_i, self.edge2conf_j
        )
        for i in range(len(self.im_conf)):
            self.im_conf[i].requires_grad = False
        self.init_conf_maps = [c.clone() for c in self.im_conf]
        self.base_scale = base_scale
        self.norm_pw_scale = True
        self.pw_break = pw_break
        self.POSE_DIM = 7
        self.pw_poses = nn.Parameter(
            rand_pose((self.n_edges, 1 + self.POSE_DIM))
        )  # pairwise poses
        self.pw_adaptors = nn.Parameter(
            torch.zeros((self.n_edges, 2))
        )  # slight xy/z adaptation
        self.pw_adaptors.requires_grad_(allow_pw_adaptors)
        self.has_im_poses = False
        self.rand_pose = rand_pose

    def get_known_poses(self):
        """Return (num_known, mask, poses) for frozen image poses, or
        (0, None, None) when image poses are not used by this optimizer."""
        if self.has_im_poses:
            known_poses_msk = torch.tensor(
                [not (p.requires_grad) for p in self.im_poses]
            )
            known_poses = self.get_im_poses()
            return known_poses_msk.sum(), known_poses_msk, known_poses
        else:
            return 0, None, None

    def get_pw_norm_scale_factor(self):
        """Scale factor renormalizing pairwise-pose scales around base_scale."""
        if self.norm_pw_scale:
            # normalize scales so that things cannot go south
            # we want that exp(scale) ~= self.base_scale
            return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp()
        else:
            return 1  # don't norm scale for known poses

    def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
        """Write (R, T, scale) into poses[idx] in place, encoded as unit
        quaternion + signed-log translation + log scale. Frozen poses are left
        untouched unless force=True. Returns the (possibly updated) pose row."""
        # all poses == cam-to-world
        pose = poses[idx]
        if not (pose.requires_grad or force):
            return pose
        if R.shape == (4, 4):
            # A full 4x4 pose was passed; split it into rotation + translation.
            assert T is None
            T = R[:3, 3]
            R = R[:3, :3]
        if R is not None:
            pose.data[0:4] = roma.rotmat_to_unitquat(R)
        if T is not None:
            pose.data[4:7] = signed_log1p(
                T / (scale or 1)
            )  # translation is function of scale
        if scale is not None:
            assert poses.shape[-1] in (8, 13)
            pose.data[-1] = np.log(float(scale))
        return pose

    def forward(self, ret_details=False):
        """Confidence-weighted alignment loss between the global point cloud and
        the per-edge predictions, averaged over all edges.

        NOTE(review): relies on get_pw_poses/get_adaptors/get_pts3d and the
        pred_i/pred_j/conf_i/conf_j mappings, which subclasses must provide.
        """
        pw_poses = self.get_pw_poses()  # cam-to-world
        pw_adapt = self.get_adaptors()
        proj_pts3d = self.get_pts3d()
        # pre-compute pixel weights
        weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()}
        weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()}
        loss = 0
        if ret_details:
            details = -torch.ones((self.n_imgs, self.n_imgs))
        for e, (i, j) in enumerate(self.edges):
            i_j = edge_str(i, j)
            # distance in image i and j
            aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j])
            aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j])
            li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean()
            lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean()
            loss = loss + li + lj
            if ret_details:
                details[i, j] = li + lj
        loss /= self.n_edges  # average over all pairs
        if ret_details:
            return loss, details
        return loss

    def compute_global_alignment(self, init=None, niter_PnP=10, **kw):
        """Optionally initialize poses (init in {None, "msp"/"mst",
        "known_poses"}), then run the global alignment loop. Extra kwargs are
        forwarded to global_alignment_loop. Raises ValueError on other inits."""
        if init is None:
            pass
        elif init == "msp" or init == "mst":
            init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)
        elif init == "known_poses":
            raise NotImplementedError
            # NOTE(review): unreachable reference code below, kept from upstream.
            self.preset_pose(known_poses=self.camera_poses, requires_grad=True)
            init_fun.init_from_known_poses(
                self, min_conf_thr=self.min_conf_thr, niter_PnP=niter_PnP
            )
        else:
            raise ValueError(f"bad value for {init=}")
        return global_alignment_loop(self, **kw)

    @property
    def str_edges(self):
        # BUGFIX: _init_from_views iterates `self.str_edges` as a sequence, so
        # this must be a property (a bare bound method is not iterable).
        return [edge_str(i, j) for i, j in self.edges]

    @property
    def n_edges(self):
        # BUGFIX: consumed as an int (`self.n_edges`) in _init_from_views and
        # forward, so this must be a property as well.
        return len(self.edges)
def global_alignment_loop(
    net,
    lr=0.01,
    niter=300,
    schedule="cosine",
    lr_min=1e-3,
    temporal_smoothing_weight=0,
    depth_map_save_dir=None,
):
    """Drive `niter` steps of global-alignment optimization over `net`.

    Returns `net` unchanged when it has no trainable parameters, otherwise the
    final loss value. When `net.verbose`, shows a tqdm bar and (if
    `depth_map_save_dir` is given) snapshots depth maps every 500 iterations.
    """
    trainable = [p for p in net.parameters() if p.requires_grad]
    if not trainable:
        return net  # nothing to optimize

    verbose = net.verbose
    if verbose:
        print("Global alignement - optimizing for:")
        print([name for name, p in net.named_parameters() if p.requires_grad])

    lr_base = lr
    optimizer = torch.optim.Adam(trainable, lr=lr, betas=(0.9, 0.9))

    def _dump_depthmaps(step):
        # Debug snapshot: write every current depth map as a jet-colored PNG.
        os.makedirs(depth_map_save_dir, exist_ok=True)
        for idx, dm in enumerate(net.get_depthmaps()):
            out_path = os.path.join(
                depth_map_save_dir, f"depthmaps_{idx}_iter_{step}.png"
            )
            plt.imsave(out_path, dm.detach().cpu().numpy(), cmap="jet")
        print(f"Saved depthmaps at iteration {step} to {depth_map_save_dir}")

    loss = float("inf")
    if verbose:
        with tqdm.tqdm(total=niter) as bar:
            for step in range(niter):
                if step % 500 == 0 and depth_map_save_dir is not None:
                    _dump_depthmaps(step)
                loss, lr = global_alignment_iter(
                    net,
                    step,
                    niter,
                    lr_base,
                    lr_min,
                    optimizer,
                    schedule,
                    temporal_smoothing_weight=temporal_smoothing_weight,
                )
                bar.set_postfix_str(f"{lr=:g} loss={loss:g}")
                bar.update()
    else:
        for step in range(niter):
            loss, _ = global_alignment_iter(
                net,
                step,
                niter,
                lr_base,
                lr_min,
                optimizer,
                schedule,
                temporal_smoothing_weight=temporal_smoothing_weight,
            )
    return loss
def global_alignment_iter(
    net,
    cur_iter,
    niter,
    lr_base,
    lr_min,
    optimizer,
    schedule,
    temporal_smoothing_weight=0,
):
    """One optimization step: pick the learning rate from `schedule`, then run
    forward/backward/step on `net`. Returns (loss_value, lr_used).

    `temporal_smoothing_weight` is accepted for interface compatibility but is
    not used in this function — presumably consumed elsewhere; TODO confirm.
    Raises ValueError for an unrecognized schedule name.
    """
    progress = cur_iter / niter
    if schedule == "cosine":
        new_lr = cosine_schedule(progress, lr_base, lr_min)
    elif schedule == "linear":
        new_lr = linear_schedule(progress, lr_base, lr_min)
    elif schedule.startswith("cycle"):
        # "cycleN" selects N cycles; an unparsable suffix falls back to 2.
        try:
            cycles = int(schedule[len("cycle"):])
        except ValueError:
            cycles = 2
        new_lr = cycled_linear_schedule(progress, lr_base, lr_min, num_cycles=cycles)
    else:
        raise ValueError(f"bad lr {schedule=}")
    adjust_learning_rate_by_lr(optimizer, new_lr)

    def _maybe_free_cuda():
        # Optional CUDA cache release between phases, gated by the net's flag.
        if net.empty_cache:
            torch.cuda.empty_cache()

    optimizer.zero_grad()
    _maybe_free_cuda()
    loss = net(epoch=cur_iter)
    _maybe_free_cuda()
    loss.backward()
    _maybe_free_cuda()
    optimizer.step()
    return float(loss), new_lr