Spaces:
Runtime error
Runtime error
# Copyright Niantic 2021. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the ManyDepth licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
def disp_to_depth(disp, min_depth=0.1, max_depth=100):
    """Convert the network's sigmoid output into a depth prediction.

    ``disp`` (a sigmoid output, so expected in [0, 1]) is mapped linearly
    onto the disparity range [1 / max_depth, 1 / min_depth] and then
    inverted. ``disp == 0`` therefore gives ``max_depth`` and ``disp == 1``
    gives ``min_depth``. The formula for this conversion is given in the
    'additional considerations' section of the paper.

    Args:
        disp: sigmoid output. Any value supporting scalar arithmetic works
            (Python float, numpy array, torch tensor).
        min_depth: smallest depth the mapping can produce.
        max_depth: largest depth the mapping can produce.

    Returns:
        Tuple ``(scaled_disp, depth)`` with ``depth == 1 / scaled_disp``.
    """
    min_disp = 1 / max_depth  # 0.01 with the default max_depth=100
    max_disp = 1 / min_depth  # 10 with the default min_depth=0.1
    scaled_disp = min_disp + (max_disp - min_disp) * disp
    depth = 1 / scaled_disp
    return scaled_disp, depth
def transformation_from_parameters(axisangle, translation, invert=False):
    """Assemble a batch of 4x4 rigid transforms from (axisangle, translation).

    Args:
        axisangle: axis-angle rotation handed to ``rot_from_axisangle``
            (whose docstring requires a ``(B, 1, 3)`` tensor).
        translation: translation handed to ``get_translation_matrix``, with
            three components per batch item. The caller's tensor is not
            modified; only a clone is negated.
        invert: when True, return the inverse transform instead.

    Returns:
        ``(B, 4, 4)`` tensor equal to ``T(t) @ R``, or to ``R^T @ T(-t)`` when
        ``invert`` is True. The latter is the inverse of ``T(t) @ R`` because
        ``R`` is a rotation, so its transpose is its inverse.
    """
    rotation = rot_from_axisangle(axisangle)
    shift = translation.clone()

    if not invert:
        # Rotate first, then translate.
        return torch.matmul(get_translation_matrix(shift), rotation)

    # (T @ R)^-1 == R^-1 @ T^-1 == R^T @ T(-t)
    shift *= -1
    return torch.matmul(rotation.transpose(1, 2), get_translation_matrix(shift))
def get_translation_matrix(translation_vector):
    """Convert a translation vector into a 4x4 homogeneous transformation matrix.

    Args:
        translation_vector: tensor holding three translation components per
            batch item. It is reshaped to ``(B, 3, 1)``, where ``B`` is its
            first dimension, so e.g. ``(B, 1, 3)`` or ``(B, 3)`` both work.

    Returns:
        ``(B, 4, 4)`` tensor on the input's device, in the default floating
        dtype (the input is cast if it differs). Each item is the identity
        matrix with the translation written into the top three rows of the
        last column.
    """
    batch_size = translation_vector.shape[0]
    # Build the identity directly on the target device instead of allocating
    # on the host and copying it over with .to(device) on every call.
    T = torch.eye(4, device=translation_vector.device).repeat(batch_size, 1, 1)
    T[:, :3, 3:] = translation_vector.contiguous().view(-1, 3, 1)
    return T
def rot_from_axisangle(vec):
    """Convert an axis-angle rotation into a 4x4 homogeneous rotation matrix.

    (adapted from https://github.com/Wallacoloo/printipi)

    Args:
        vec: ``(B, 1, 3)`` tensor. Its direction is the rotation axis and its
            L2 norm is the rotation angle. The angle is fed straight into
            cos/sin, so it is in radians.

    Returns:
        ``(B, 4, 4)`` tensor on ``vec``'s device, in the default floating
        dtype. The upper-left 3x3 block holds the Rodrigues rotation matrix,
        ``[3, 3]`` is 1 and every other entry is 0.
    """
    angle = torch.norm(vec, 2, 2, True)  # L2 norm over the last axis -> (B, 1, 1)
    # The epsilon avoids 0/0 for a zero vector. That case gives axis == 0 and
    # angle == 0, which the formulas below turn into the identity.
    axis = vec / (angle + 1e-7)

    ca = torch.cos(angle)
    sa = torch.sin(angle)
    C = 1 - ca

    # Each axis component as a (B, 1, 1) tensor, matching ca / sa / C.
    x = axis[..., 0].unsqueeze(1)
    y = axis[..., 1].unsqueeze(1)
    z = axis[..., 2].unsqueeze(1)

    xs = x * sa
    ys = y * sa
    zs = z * sa
    xC = x * C
    yC = y * C
    zC = z * C
    xyC = x * yC
    yzC = y * zC
    zxC = z * xC

    # Allocate directly on vec's device rather than on the host followed by
    # a .to(device) copy.
    rot = torch.zeros((vec.shape[0], 4, 4), device=vec.device)

    rot[:, 0, 0] = torch.squeeze(x * xC + ca)
    rot[:, 0, 1] = torch.squeeze(xyC - zs)
    rot[:, 0, 2] = torch.squeeze(zxC + ys)
    rot[:, 1, 0] = torch.squeeze(xyC + zs)
    rot[:, 1, 1] = torch.squeeze(y * yC + ca)
    rot[:, 1, 2] = torch.squeeze(yzC - xs)
    rot[:, 2, 0] = torch.squeeze(zxC - ys)
    rot[:, 2, 1] = torch.squeeze(yzC + xs)
    rot[:, 2, 2] = torch.squeeze(z * zC + ca)
    rot[:, 3, 3] = 1

    return rot
def normalize(img):
    """Linearly rescale ``img`` so that its values span [0, 1].

    Works on anything that supports ``min()``/``max()`` and elementwise
    arithmetic (torch tensors, numpy arrays). A constant input has no range
    to stretch. Instead of dividing 0 by 0, which yields NaN everywhere, it is
    mapped to all zeros.
    """
    lo = img.min()
    span = img.max() - lo
    shifted = img - lo
    if span == 0:
        # `shifted` is already all zeros. Multiplying by 0.0 makes the result
        # floating point, like the division branch, even for integer input.
        return shifted * 0.0
    return shifted / span
def _percentile_stretch(x, low=0.05, high=0.95):
    """Clip ``x`` to its [low, high] quantile range, then rescale it with ``normalize``.

    Returns a new tensor with the same shape as ``x``; ``x`` is not modified.
    ``x`` must be float32 or float64, which ``torch.quantile`` requires.
    """
    # torch.quantile requires q to match the input's dtype (and device). A
    # hard-coded float32 q would be rejected for float64 input.
    q = torch.tensor((low, high), dtype=x.dtype, device=x.device)
    lo, hi = torch.quantile(x.flatten(), q)
    return normalize(torch.clamp(x, lo, hi))


def line(img):
    """Percentile-stretch a 1- or 3-channel image for display.

    A leading axis is added first, so dim 0 of ``img`` is treated as the
    channel axis (presumably ``img`` is ``(C, H, W)``; TODO confirm against
    callers). Each channel is clipped to its own 5th-95th percentile range
    and then rescaled independently via ``normalize``. ``img`` itself is
    never modified.

    Returns:
        * 1 channel: the stretched image with the added leading axis kept,
          i.e. shape ``(1, *img.shape)``.
        * 3 channels: the stretched image with every size-1 dim squeezed.
        * any other channel count: a squeezed, unstretched copy of ``img``.
    """
    # unsqueeze() returns a view. Without the clone, the in-place channel
    # writes below would overwrite the caller's tensor.
    img = img.unsqueeze(0).clone()
    if img.shape[1] == 1:
        return _percentile_stretch(img)
    if img.shape[1] == 3:
        for c in range(3):
            img[:, c:c + 1] = _percentile_stretch(img[:, c:c + 1])
    # NOTE(review): the 1-channel branch returns before this squeeze, so the
    # two supported channel counts come back with different ranks.
    return img.squeeze()