# Copyright Niantic 2021. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the ManyDepth licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def disp_to_depth(disp, min_depth=0.1, max_depth=100):
    """Convert network's sigmoid output into depth prediction

    The formula for this conversion is given in the 'additional considerations'
    section of the paper.

    Args:
        disp: sigmoid output; 0 maps to ``max_depth`` and 1 to ``min_depth``.
        min_depth: depth returned for ``disp == 1``.
        max_depth: depth returned for ``disp == 0``.

    Returns:
        ``(scaled_disp, depth)`` where ``scaled_disp`` is ``disp`` rescaled
        linearly into ``[1 / max_depth, 1 / min_depth]`` and
        ``depth = 1 / scaled_disp``.
    """
    min_disp = 1 / max_depth  # 0.01 with the default max_depth
    max_disp = 1 / min_depth  # 10 with the default min_depth
    scaled_disp = min_disp + (max_disp - min_disp) * disp
    depth = 1 / scaled_disp
    return scaled_disp, depth


def transformation_from_parameters(axisangle, translation, invert=False):
    """Convert the network's (axisangle, translation) output into a 4x4 matrix

    Args:
        axisangle: axis-angle rotation, passed to ``rot_from_axisangle``
            (which expects Bx1x3).
        translation: translation with 3 values per batch element; it is
            cloned, so the caller's tensor is never modified.
        invert: when True, build the inverse transform instead: the rotation
            is transposed, the translation negated and the multiplication
            order swapped.

    Returns:
        Bx4x4 tensor ``T @ R``, or ``R^T @ T(-t)`` when ``invert`` is True.
    """
    R = rot_from_axisangle(axisangle)
    t = translation.clone()

    if invert:
        R = R.transpose(1, 2)
        t *= -1

    T = get_translation_matrix(t)

    # (T @ R)^-1 == R^-1 @ T^-1, hence the swapped order for the inverse.
    if invert:
        return torch.matmul(R, T)
    return torch.matmul(T, R)


def get_translation_matrix(translation_vector):
    """Convert a translation vector into a 4x4 transformation matrix

    Args:
        translation_vector: tensor whose first dimension is the batch size and
            which holds 3 values per batch element.

    Returns:
        Bx4x4 tensor (default dtype, on the input's device) with ones on the
        diagonal and the translation in the first three rows of the last
        column.
    """
    # Allocate directly on the target device instead of building the tensor on
    # the host and copying it across.
    T = torch.zeros(translation_vector.shape[0], 4, 4,
                    device=translation_vector.device)

    t = translation_vector.contiguous().view(-1, 3, 1)

    T[:, 0, 0] = 1
    T[:, 1, 1] = 1
    T[:, 2, 2] = 1
    T[:, 3, 3] = 1
    T[:, :3, 3, None] = t

    return T


def rot_from_axisangle(vec):
    """Convert an axisangle rotation into a 4x4 transformation matrix
    (adapted from https://github.com/Wallacoloo/printipi)
    Input 'vec' has to be Bx1x3

    The rotation angle is the norm of ``vec`` along its last dimension and the
    rotation axis is ``vec`` divided by that norm.

    Returns:
        Bx4x4 tensor (default dtype, on the input's device) holding the
        rotation in its upper-left 3x3 part and 1 at ``[3, 3]``.
    """
    angle = torch.norm(vec, 2, 2, True)
    # The small epsilon keeps the division finite for a zero rotation.
    axis = vec / (angle + 1e-7)

    ca = torch.cos(angle)
    sa = torch.sin(angle)
    C = 1 - ca

    x = axis[..., 0].unsqueeze(1)
    y = axis[..., 1].unsqueeze(1)
    z = axis[..., 2].unsqueeze(1)

    xs = x * sa
    ys = y * sa
    zs = z * sa
    xC = x * C
    yC = y * C
    zC = z * C
    xyC = x * yC
    yzC = y * zC
    zxC = z * xC

    # Allocate directly on the target device instead of building the tensor on
    # the host and copying it across.
    rot = torch.zeros((vec.shape[0], 4, 4), device=vec.device)

    rot[:, 0, 0] = torch.squeeze(x * xC + ca)
    rot[:, 0, 1] = torch.squeeze(xyC - zs)
    rot[:, 0, 2] = torch.squeeze(zxC + ys)
    rot[:, 1, 0] = torch.squeeze(xyC + zs)
    rot[:, 1, 1] = torch.squeeze(y * yC + ca)
    rot[:, 1, 2] = torch.squeeze(yzC - xs)
    rot[:, 2, 0] = torch.squeeze(zxC - ys)
    rot[:, 2, 1] = torch.squeeze(yzC + xs)
    rot[:, 2, 2] = torch.squeeze(z * zC + ca)
    rot[:, 3, 3] = 1

    return rot


def normalize(img):
    """Rescale ``img`` linearly so its minimum maps to 0 and its maximum to 1.

    The minimum and maximum are taken over the whole input. A constant input
    (maximum equal to minimum) yields all zeros instead of NaN.
    """
    lo = img.min()
    span = img.max() - lo
    if span == 0:
        # Constant input: ``img - lo`` is already zero everywhere, so divide
        # by one rather than by zero to avoid producing NaNs.
        span = 1
    return (img - lo) / span


def _clip_to_quantiles_and_normalize(img):
    """Clip ``img`` to its 5%/95% quantiles, then rescale it with ``normalize``."""
    # torch.quantile requires ``q`` to have the same dtype as its input.
    q = torch.tensor((0.05, 0.95), dtype=img.dtype, device=img.device)
    q5, q95 = torch.quantile(img.flatten(), q=q)
    # Out-of-place clipping so the tensor passed in is left untouched.
    clipped = torch.min(torch.max(img, q5), q95)
    return normalize(clipped)


def line(img):
    """Clip an image to its 5%/95% quantiles and rescale it to [0, 1].

    A leading dimension is added to ``img`` and the size of the following
    dimension (the first dimension of the input) selects the behaviour:

    * 1: the whole image is clipped and normalized; the result keeps the
      added leading dimension.
    * 3: each of the three channels is clipped and normalized on its own and
      the squeezed result is returned.
    * anything else: the input is returned squeezed but otherwise unchanged.

    The input tensor is never modified.
    """
    img = img.unsqueeze(0)
    channels = img.shape[1]

    if channels == 1:
        return _clip_to_quantiles_and_normalize(img)

    if channels == 3:
        # ``unsqueeze`` returns a view, so write into a copy rather than into
        # the caller's tensor.
        out = img.clone()
        for c in range(3):
            out[:, c:c + 1] = _clip_to_quantiles_and_normalize(img[:, c:c + 1])
        return out.squeeze()

    return img.squeeze()