import torch
import torchvision.transforms as tvf

# DINOv2 backbones use a ViT patch size of 14; inputs are cropped to a multiple of it.
dino_patch_size = 14


def batch_to_device(batch, device='cuda'):
    """Recursively move every tensor in a (possibly nested) dict to `device`."""
    for key in batch:
        if isinstance(batch[key], torch.Tensor):
            batch[key] = batch[key].to(device)
        elif isinstance(batch[key], dict):
            batch[key] = batch_to_device(batch[key], device)
    return batch


def compute_pointmap(depth: torch.Tensor, intrinsics: torch.Tensor, cam2world: torch.Tensor = None) -> torch.Tensor:
    """Back-project an H x W depth map into a per-pixel 3D pointmap.

    Points are returned in camera coordinates, or in world coordinates when a
    4x4 `cam2world` pose is given. Output shape is H x W x 3.
    """
    fx, fy = intrinsics[0, 0], intrinsics[1, 1]
    cx, cy = intrinsics[0, 2], intrinsics[1, 2]
    h, w = depth.shape
    # Pixel coordinate grids of shape (h, w); i indexes columns (x), j rows (y).
    i, j = torch.meshgrid(torch.arange(w), torch.arange(h), indexing='xy')
    i = i.to(depth.device)
    j = j.to(depth.device)
    # Invert the pinhole projection to recover camera-space coordinates.
    x_cam = (i - cx) * depth / fx
    y_cam = (j - cy) * depth / fy
    points_cam = torch.stack([x_cam, y_cam, depth], dim=-1)
    if cam2world is not None:
        # Apply the rigid transform: rotate, then translate into the world frame.
        points_cam = torch.matmul(cam2world[:3, :3], points_cam.reshape(-1, 3).T).T + cam2world[:3, 3]
        points_cam = points_cam.reshape(h, w, 3)
    return points_cam
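
# A minimal sanity check (hypothetical values, not taken from any dataset):
#   depth = torch.ones(4, 4)
#   K = torch.tensor([[2.0, 0.0, 2.0], [0.0, 2.0, 2.0], [0.0, 0.0, 1.0]])
#   compute_pointmap(depth, K).shape  # -> torch.Size([4, 4, 3])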


def compute_pointmaps(depths: torch.Tensor, intrinsics: torch.Tensor, cam2worlds: torch.Tensor) -> torch.Tensor:
    """Apply `compute_pointmap` over a (B, V, H, W) batch of depth maps."""
    pointmaps = []
    pointmaps_shape = depths.shape + (3,)
    for depth, K, c2w in zip(depths, intrinsics, cam2worlds):
        n_views = depth.shape[0]
        for i in range(n_views):
            pointmaps.append(compute_pointmap(depth[i], K[i], c2w[i]))
    return torch.stack(pointmaps).reshape(pointmaps_shape)


def depth_to_metric(depth):
    # depth: tensor of raw 16-bit values in [0, 65535].
    # Rescale to metric depth in meters, with 65535 mapping to depth_max.
    depth_max = 10.0
    depth_scaled = depth_max * (depth / 65535.0)
    return depth_scaled


def make_rgb_transform() -> tvf.Compose:
    # Normalize with the ImageNet statistics scaled to the 0-255 range,
    # i.e. mean = (0.485, 0.456, 0.406) * 255 and std = (0.229, 0.224, 0.225) * 255.
    return tvf.Compose([
        # tvf.ToTensor(),
        # lambda x: 255.0 * x[:3],  # Discard alpha component and scale by 255
        tvf.Normalize(
            mean=(123.675, 116.28, 103.53),
            std=(58.395, 57.12, 57.375),
        ),
    ])


rgb_transform = make_rgb_transform()
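
# Note: as written, `rgb_transform` assumes float CHW tensors already in the
# 0-255 range; the commented-out steps above show the alternative PIL input path.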


def compute_dino_and_store_features(dino_model: torch.nn.Module, rgb: torch.Tensor, mask: torch.Tensor, dino_layers: list[int] = None) -> torch.Tensor:
    """Compute DINO features for a batch of masked RGB images.

    `rgb` has shape (B, 1, H, W, 3) and `mask` (B, 1, H, W); the singleton view
    dimension is squeezed away. Features from all requested layers are
    concatenated along the channel dimension.
    """
    rgb = rgb.squeeze(1)
    mask = mask.squeeze(1)
    rgb = rgb.permute(0, 3, 1, 2)  # BHWC -> BCHW
    mask = mask.unsqueeze(1).repeat(1, 3, 1, 1)
    rgb = rgb * mask  # Zero out invalid pixels before feature extraction.
    rgb = rgb.float()
    # Center-crop so both sides are multiples of the ViT patch size.
    H, W = rgb.shape[-2:]
    goal_H, goal_W = H // dino_patch_size * dino_patch_size, W // dino_patch_size * dino_patch_size
    resize_transform = tvf.CenterCrop([goal_H, goal_W])
    with torch.no_grad():
        rgb = resize_transform(rgb)
        rgb = rgb_transform(rgb)
        all_feat = dino_model.get_intermediate_layers(rgb, dino_layers)
    dino_feat = torch.cat(all_feat, dim=-1)
    return dino_feat
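
# Usage sketch (an assumption: any backbone exposing DINOv2's
# `get_intermediate_layers` API works, e.g. loaded via torch.hub):
#   dino_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14').eval()
#   feats = compute_dino_and_store_features(dino_model, imgs, masks, dino_layers=[11])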


def prepare_fast_batch(batch, dino_model=None, dino_layers=None):
    # Convert raw 16-bit depths to metric depths.
    batch['new_cams']['depths'] = depth_to_metric(batch['new_cams']['depths'])
    batch['input_cams']['depths'] = depth_to_metric(batch['input_cams']['depths'])
    # Back-project depths to world-space pointmaps.
    batch['new_cams']['pointmaps'] = compute_pointmaps(batch['new_cams']['depths'], batch['new_cams']['Ks'], batch['new_cams']['c2ws'])
    batch['input_cams']['pointmaps'] = compute_pointmaps(batch['input_cams']['depths'], batch['input_cams']['Ks'], batch['input_cams']['c2ws'])
    # Compute DINO features for the input views (also guards against dino_layers=None).
    if dino_model is not None and dino_layers:
        batch['input_cams']['dino_features'] = compute_dino_and_store_features(dino_model, batch['input_cams']['imgs'], batch['input_cams']['valid_masks'], dino_layers)
    return batch


def normalize_batch(batch, normalize_mode):
    """Rescale each scene's geometry; returns the per-scene scale factors."""
    scale_factors = []
    if normalize_mode == 'None':
        pass
    elif normalize_mode == 'median':
        # Scale each scene so the median valid input depth becomes 1.
        B = batch['input_cams']['valid_masks'].shape[0]
        for b in range(B):
            input_mask = batch['input_cams']['valid_masks'][b]
            depth_median = batch['input_cams']['depths'][b][input_mask].median()
            scale_factor = 1.0 / depth_median
            scale_factors.append(scale_factor)
            # Depths, pointmaps, and camera translations must all be scaled
            # consistently (denormalize_batch applies the exact inverse).
            batch['input_cams']['depths'][b] = scale_factor * batch['input_cams']['depths'][b]
            batch['input_cams']['pointmaps'][b] = scale_factor * batch['input_cams']['pointmaps'][b]
            batch['input_cams']['c2ws'][b][:, :3, -1] = scale_factor * batch['input_cams']['c2ws'][b][:, :3, -1]
            batch['new_cams']['depths'][b] = scale_factor * batch['new_cams']['depths'][b]
            batch['new_cams']['pointmaps'][b] = scale_factor * batch['new_cams']['pointmaps'][b]
            batch['new_cams']['c2ws'][b][:, :3, -1] = scale_factor * batch['new_cams']['c2ws'][b][:, :3, -1]
    return batch, scale_factors


def denormalize_batch(batch, pred, gt, scale_factors):
    """Undo `normalize_batch` on the batch, predictions, and ground truth."""
    B = len(scale_factors)
    for b in range(B):
        new_scale_factor = 1.0 / scale_factors[b]
        batch['input_cams']['depths'][b] = new_scale_factor * batch['input_cams']['depths'][b]
        batch['input_cams']['pointmaps'][b] = new_scale_factor * batch['input_cams']['pointmaps'][b]
        batch['input_cams']['c2ws'][b][:, :3, -1] = new_scale_factor * batch['input_cams']['c2ws'][b][:, :3, -1]
        batch['new_cams']['depths'][b] = new_scale_factor * batch['new_cams']['depths'][b]
        batch['new_cams']['pointmaps'][b] = new_scale_factor * batch['new_cams']['pointmaps'][b]
        batch['new_cams']['c2ws'][b][:, :3, -1] = new_scale_factor * batch['new_cams']['c2ws'][b][:, :3, -1]
        pred['depths'][b] = new_scale_factor * pred['depths'][b]
        gt['c2ws'][b][:, :3, -1] = new_scale_factor * gt['c2ws'][b][:, :3, -1]
        gt['depths'][b] = new_scale_factor * gt['depths'][b]
        # Recompute pointmaps from the denormalized depths and poses rather than rescaling them.
        gt['pointmaps'][b] = compute_pointmaps(gt['depths'][b].unsqueeze(1), gt['Ks'][b].unsqueeze(1), gt['c2ws'][b].unsqueeze(1)).squeeze(1)
        pred['pointmaps'][b] = compute_pointmaps(pred['depths'][b].unsqueeze(1), gt['Ks'][b].unsqueeze(1), gt['c2ws'][b].unsqueeze(1)).squeeze(1)
    return batch, pred, gt
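

if __name__ == "__main__":
    # Smoke test on synthetic data; all shapes and values below are
    # illustrative assumptions (1 scene, 1 input view, 2 target views).
    B, H, W = 1, 28, 28
    K = torch.tensor([[100.0, 0.0, W / 2], [0.0, 100.0, H / 2], [0.0, 0.0, 1.0]])
    eye = torch.eye(4)
    batch = {
        'input_cams': {
            'depths': torch.full((B, 1, H, W), 30000.0),
            'Ks': K.expand(B, 1, 3, 3).clone(),
            'c2ws': eye.expand(B, 1, 4, 4).clone(),
            'valid_masks': torch.ones(B, 1, H, W, dtype=torch.bool),
        },
        'new_cams': {
            'depths': torch.full((B, 2, H, W), 30000.0),
            'Ks': K.expand(B, 2, 3, 3).clone(),
            'c2ws': eye.expand(B, 2, 4, 4).clone(),
        },
    }
    batch = prepare_fast_batch(batch)
    batch, scale_factors = normalize_batch(batch, 'median')
    print(batch['input_cams']['pointmaps'].shape, scale_factors)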