import torch
import torchvision.transforms as tvf

# Patch size of the DINO backbone; input crops are snapped to multiples of this.
dino_patch_size = 14


def batch_to_device(batch, device='cuda'):
    """Recursively move every tensor in a (possibly nested) dict to `device`."""
    for key in batch:
        if isinstance(batch[key], torch.Tensor):
            batch[key] = batch[key].to(device)
        elif isinstance(batch[key], dict):
            batch[key] = batch_to_device(batch[key], device)
    return batch
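
# Usage sketch (hypothetical batch layout; nested dicts are handled recursively):
#
#   batch = {'imgs': torch.rand(2, 3, 224, 224), 'cams': {'Ks': torch.eye(3)}}
#   batch = batch_to_device(batch, device='cpu')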


def compute_pointmap(depth: torch.Tensor, intrinsics: torch.Tensor, cam2world: torch.Tensor = None) -> torch.Tensor:
    """Unproject an (H, W) depth map to an (H, W, 3) pointmap, optionally in world frame."""
    fx, fy = intrinsics[0, 0], intrinsics[1, 1]
    cx, cy = intrinsics[0, 2], intrinsics[1, 2]
    h, w = depth.shape

    # Pixel grids of shape (h, w): i holds column (x) indices, j holds row (y) indices.
    i, j = torch.meshgrid(torch.arange(w), torch.arange(h), indexing='xy')
    i = i.to(depth.device)
    j = j.to(depth.device)

    # Back-project through the pinhole model.
    x_cam = (i - cx) * depth / fx
    y_cam = (j - cy) * depth / fy

    points = torch.stack([x_cam, y_cam, depth], dim=-1)

    # Optionally transform camera-frame points into world coordinates.
    if cam2world is not None:
        points = torch.matmul(cam2world[:3, :3], points.reshape(-1, 3).T).T + cam2world[:3, 3]
        points = points.reshape(h, w, 3)

    return points
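
# Sanity-check sketch (hypothetical intrinsics): with a constant 2.0 depth map and
# principal point (32, 24), the centre pixel unprojects to (0, 0, 2) in camera frame.
#
#   K = torch.tensor([[100.0, 0.0, 32.0], [0.0, 100.0, 24.0], [0.0, 0.0, 1.0]])
#   pts = compute_pointmap(torch.full((48, 64), 2.0), K)   # (48, 64, 3)
#   assert torch.allclose(pts[24, 32], torch.tensor([0.0, 0.0, 2.0]))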


def compute_pointmaps(depths: torch.Tensor, intrinsics: torch.Tensor, cam2worlds: torch.Tensor) -> torch.Tensor:
    """Unproject a (B, V, H, W) stack of depth maps to (B, V, H, W, 3) pointmaps."""
    pointmaps = []
    pointmaps_shape = depths.shape + (3,)
    for depth, K, c2w in zip(depths, intrinsics, cam2worlds):
        n_views = depth.shape[0]
        for i in range(n_views):
            pointmaps.append(compute_pointmap(depth[i], K[i], c2w[i]))
    return torch.stack(pointmaps).reshape(pointmaps_shape)
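
# Shape sketch (illustrative sizes): per-view intrinsics and poses are zipped with
# the depths, so all three leading (B, V) layouts must match.
#
#   depths = torch.rand(2, 3, 48, 64)                  # (B, V, H, W)
#   Ks = torch.eye(3).expand(2, 3, 3, 3)               # (B, V, 3, 3)
#   c2ws = torch.eye(4).expand(2, 3, 4, 4)             # (B, V, 4, 4)
#   pms = compute_pointmaps(depths, Ks, c2ws)          # (2, 3, 48, 64, 3)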


def depth_to_metric(depth):
    """Map uint16-encoded depth values in [0, 65535] linearly onto [0, depth_max]."""
    depth_max = 10.0
    depth_scaled = depth_max * (depth / 65535.0)
    return depth_scaled
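
# Example: the encoding is linear, so the largest code 65535 decodes to the
# 10.0 maximum and 0 decodes to 0.
#
#   depth_to_metric(torch.tensor([0.0, 65535.0]))   # tensor([ 0., 10.])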


def make_rgb_transform() -> tvf.Compose:
    # ImageNet mean/std pre-multiplied by 255, so inputs stay on the raw [0, 255] scale.
    return tvf.Compose([
        tvf.Normalize(
            mean=(123.675, 116.28, 103.53),
            std=(58.395, 57.12, 57.375),
        ),
    ])


rgb_transform = make_rgb_transform()
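
# Note: because the statistics are on the [0, 255] scale, no ToTensor()-style /255
# rescale should precede this transform. E.g. an all-white (255) image normalizes
# to roughly (2.25, 2.43, 2.64) per channel:
#
#   rgb_transform(torch.full((3, 14, 14), 255.0))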


def compute_dino_and_store_features(dino_model: torch.nn.Module, rgb: torch.Tensor, mask: torch.Tensor, dino_layers: list[int] = None) -> torch.Tensor:
    """Computes DINO features for a masked RGB image, concatenating the requested layers."""
    rgb = rgb.squeeze(1)
    mask = mask.squeeze(1)
    rgb = rgb.permute(0, 3, 1, 2)  # (B, H, W, 3) -> (B, 3, H, W)
    mask = mask.unsqueeze(1).repeat(1, 3, 1, 1)
    rgb = rgb * mask  # zero out invalid pixels

    rgb = rgb.float()
    H, W = rgb.shape[-2:]
    # Crop so both sides are multiples of the DINO patch size.
    goal_H, goal_W = H // dino_patch_size * dino_patch_size, W // dino_patch_size * dino_patch_size
    resize_transform = tvf.CenterCrop([goal_H, goal_W])
    with torch.no_grad():
        rgb = resize_transform(rgb)
        rgb = rgb_transform(rgb)
        all_feat = dino_model.get_intermediate_layers(rgb, dino_layers)
    dino_feat = torch.cat(all_feat, dim=-1)
    return dino_feat
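
# Hedged usage sketch, assuming a DINOv2 ViT-B/14 backbone from torch.hub (shapes
# and the layer index are illustrative):
#
#   dino = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14').eval().cuda()
#   imgs = torch.rand(2, 1, 224, 224, 3, device='cuda') * 255           # raw [0, 255]
#   masks = torch.ones(2, 1, 224, 224, dtype=torch.bool, device='cuda')
#   feats = compute_dino_and_store_features(dino, imgs, masks, dino_layers=[11])
#   # feats: (2, (224 // 14) ** 2, 768) = (2, 256, 768)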


def prepare_fast_batch(batch, dino_model=None, dino_layers=None):
    # Decode stored uint16 depths into metric depths.
    batch['new_cams']['depths'] = depth_to_metric(batch['new_cams']['depths'])
    batch['input_cams']['depths'] = depth_to_metric(batch['input_cams']['depths'])

    # Unproject the depths into world-frame pointmaps.
    batch['new_cams']['pointmaps'] = compute_pointmaps(batch['new_cams']['depths'], batch['new_cams']['Ks'], batch['new_cams']['c2ws'])
    batch['input_cams']['pointmaps'] = compute_pointmaps(batch['input_cams']['depths'], batch['input_cams']['Ks'], batch['input_cams']['c2ws'])

    # Truthiness check also guards against dino_layers being None, not just empty.
    if dino_model is not None and dino_layers:
        batch['input_cams']['dino_features'] = compute_dino_and_store_features(dino_model, batch['input_cams']['imgs'], batch['input_cams']['valid_masks'], dino_layers)

    return batch
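
# Expected batch schema, inferred from the accesses above (shapes illustrative):
#
#   batch = {
#       'input_cams': {'imgs': (B, 1, H, W, 3), 'valid_masks': (B, 1, H, W),
#                      'depths': (B, V, H, W), 'Ks': (B, V, 3, 3), 'c2ws': (B, V, 4, 4)},
#       'new_cams':   {'depths': (B, V', H, W), 'Ks': (B, V', 3, 3), 'c2ws': (B, V', 4, 4)},
#   }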


def normalize_batch(batch, normalize_mode):
    scale_factors = []
    if normalize_mode == 'None':  # the literal string 'None', not Python None
        pass
    elif normalize_mode == 'median':
        B = batch['input_cams']['valid_masks'].shape[0]
        for b in range(B):
            # Rescale each sample so the median valid input depth becomes 1.0.
            input_mask = batch['input_cams']['valid_masks'][b]
            depth_median = batch['input_cams']['depths'][b][input_mask].median()
            scale_factor = 1.0 / depth_median
            scale_factors.append(scale_factor)
            batch['input_cams']['depths'][b] = scale_factor * batch['input_cams']['depths'][b]
            batch['input_cams']['pointmaps'][b] = scale_factor * batch['input_cams']['pointmaps'][b]
            # Scale every input-cam translation (mirrors denormalize_batch below).
            batch['input_cams']['c2ws'][b][:, :3, -1] = scale_factor * batch['input_cams']['c2ws'][b][:, :3, -1]

            batch['new_cams']['depths'][b] = scale_factor * batch['new_cams']['depths'][b]
            batch['new_cams']['pointmaps'][b] = scale_factor * batch['new_cams']['pointmaps'][b]
            batch['new_cams']['c2ws'][b][:, :3, -1] = scale_factor * batch['new_cams']['c2ws'][b][:, :3, -1]

    return batch, scale_factors
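
# Sketch of the 'median' mode: scene scale becomes dimensionless (median input
# depth == 1.0), and the returned factors let denormalize_batch invert it exactly.
#
#   batch, scales = normalize_batch(batch, 'median')
#   # scales[b] == 1.0 / median(valid input depths of sample b)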


def denormalize_batch(batch, pred, gt, scale_factors):
    """Undo the per-sample scaling applied by normalize_batch."""
    B = len(scale_factors)
    for b in range(B):
        inv_scale = 1.0 / scale_factors[b]
        batch['input_cams']['depths'][b] = inv_scale * batch['input_cams']['depths'][b]
        batch['input_cams']['pointmaps'][b] = inv_scale * batch['input_cams']['pointmaps'][b]
        batch['input_cams']['c2ws'][b][:, :3, -1] = inv_scale * batch['input_cams']['c2ws'][b][:, :3, -1]
        batch['new_cams']['depths'][b] = inv_scale * batch['new_cams']['depths'][b]
        batch['new_cams']['pointmaps'][b] = inv_scale * batch['new_cams']['pointmaps'][b]
        batch['new_cams']['c2ws'][b][:, :3, -1] = inv_scale * batch['new_cams']['c2ws'][b][:, :3, -1]

        pred['depths'][b] = inv_scale * pred['depths'][b]

        gt['c2ws'][b][:, :3, -1] = inv_scale * gt['c2ws'][b][:, :3, -1]
        gt['depths'][b] = inv_scale * gt['depths'][b]

        # Recompute pointmaps from the rescaled depths and poses.
        gt['pointmaps'][b] = compute_pointmaps(gt['depths'][b].unsqueeze(1), gt['Ks'][b].unsqueeze(1), gt['c2ws'][b].unsqueeze(1)).squeeze(1)
        pred['pointmaps'][b] = compute_pointmaps(pred['depths'][b].unsqueeze(1), gt['Ks'][b].unsqueeze(1), gt['c2ws'][b].unsqueeze(1)).squeeze(1)
    return batch, pred, gt
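
# Round-trip sketch: denormalize_batch inverts normalize_batch and additionally
# rebuilds pred/gt pointmaps at the restored scale (the model call is hypothetical):
#
#   batch, scales = normalize_batch(batch, 'median')
#   pred, gt = run_model(batch)   # hypothetical forward pass producing pred/gt dicts
#   batch, pred, gt = denormalize_batch(batch, pred, gt, scales)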