# rayst3r/utils/batch_prep.py
import torch
import torchvision.transforms as tvf

# Patch size of the DINO vision transformer; inputs are cropped to multiples of this.
dino_patch_size = 14

def batch_to_device(batch, device='cuda'):
    """Recursively move all tensors in a (possibly nested) dict to `device`."""
    for key in batch:
        if isinstance(batch[key], torch.Tensor):
            batch[key] = batch[key].to(device)
        elif isinstance(batch[key], dict):
            batch[key] = batch_to_device(batch[key], device)
    return batch

def compute_pointmap(depth: torch.Tensor, intrinsics: torch.Tensor, cam2world: torch.Tensor = None) -> torch.Tensor:
    """Back-project an (H, W) depth map to an (H, W, 3) pointmap via the pinhole model."""
    fx, fy = intrinsics[0, 0], intrinsics[1, 1]
    cx, cy = intrinsics[0, 2], intrinsics[1, 2]
    h, w = depth.shape
    # Pixel coordinate grids: i indexes columns (x), j indexes rows (y).
    i, j = torch.meshgrid(torch.arange(w, device=depth.device),
                          torch.arange(h, device=depth.device), indexing='xy')
    x_cam = (i - cx) * depth / fx
    y_cam = (j - cy) * depth / fy
    points = torch.stack([x_cam, y_cam, depth], dim=-1)
    if cam2world is not None:
        # Rotate and translate the camera-frame points into world coordinates.
        points = torch.matmul(cam2world[:3, :3], points.reshape(-1, 3).T).T + cam2world[:3, 3]
        points = points.reshape(h, w, 3)
    return points
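
# Example (a sketch with assumed values, not from the original file):
#   K = torch.tensor([[100., 0., 16.], [0., 100., 16.], [0., 0., 1.]])
#   depth = torch.ones(32, 32)
#   pts = compute_pointmap(depth, K)  # pts.shape == (32, 32, 3); pts[..., 2] is the depth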

def compute_pointmaps(depths: torch.Tensor, intrinsics: torch.Tensor, cam2worlds: torch.Tensor) -> torch.Tensor:
    """Back-project a (B, n_views, H, W) depth batch to (B, n_views, H, W, 3) pointmaps."""
    pointmaps = []
    pointmaps_shape = depths.shape + (3,)
    for depth, K, c2w in zip(depths, intrinsics, cam2worlds):
        n_views = depth.shape[0]
        for i in range(n_views):
            pointmaps.append(compute_pointmap(depth[i], K[i], c2w[i]))
    return torch.stack(pointmaps).reshape(pointmaps_shape)

def depth_to_metric(depth):
    """Convert raw uint16 depth values (0..65535) to metric depth."""
    depth_max = 10.0  # full uint16 range maps to [0, 10] in metric units, e.g. 65535 -> 10.0
    return depth_max * (depth / 65535.0)

def make_rgb_transform() -> tvf.Compose:
    """ImageNet normalization for RGB tensors already scaled to [0, 255]."""
    return tvf.Compose([
        tvf.Normalize(
            mean=(123.675, 116.28, 103.53),
            std=(58.395, 57.12, 57.375),
        ),
    ])

rgb_transform = make_rgb_transform()

def compute_dino_and_store_features(dino_model: torch.nn.Module, rgb: torch.Tensor, mask: torch.Tensor, dino_layers: list[int] = None) -> torch.Tensor:
    """Compute DINO features for a masked RGB image.

    rgb: (B, 1, H, W, 3) image batch; mask: (B, 1, H, W) validity mask.
    Returns the requested intermediate-layer features concatenated channel-wise.
    """
    rgb = rgb.squeeze(1)                         # (B, H, W, 3)
    mask = mask.squeeze(1)                       # (B, H, W)
    rgb = rgb.permute(0, 3, 1, 2)                # (B, 3, H, W)
    mask = mask.unsqueeze(1).repeat(1, 3, 1, 1)  # broadcast mask to all channels
    rgb = (rgb * mask).float()                   # zero out invalid pixels
    H, W = rgb.shape[-2:]
    # Crop to the largest size divisible by the DINO patch size.
    goal_H, goal_W = H // dino_patch_size * dino_patch_size, W // dino_patch_size * dino_patch_size
    resize_transform = tvf.CenterCrop([goal_H, goal_W])
    with torch.no_grad():
        rgb = resize_transform(rgb)
        rgb = rgb_transform(rgb)
        all_feat = dino_model.get_intermediate_layers(rgb, dino_layers)
    dino_feat = torch.cat(all_feat, dim=-1)
    return dino_feat
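
# Example (an assumption, not pinned by this file): a compatible backbone can be
# loaded with torch.hub, e.g.
#   dino_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
# and queried with dino_layers such as [22, 23] for the last two ViT-L blocks.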

def prepare_fast_batch(batch, dino_model=None, dino_layers=None):
    # Convert raw depth to metric depth.
    batch['new_cams']['depths'] = depth_to_metric(batch['new_cams']['depths'])
    batch['input_cams']['depths'] = depth_to_metric(batch['input_cams']['depths'])
    # Compute pointmaps from depth, intrinsics, and camera poses.
    batch['new_cams']['pointmaps'] = compute_pointmaps(batch['new_cams']['depths'], batch['new_cams']['Ks'], batch['new_cams']['c2ws'])
    batch['input_cams']['pointmaps'] = compute_pointmaps(batch['input_cams']['depths'], batch['input_cams']['Ks'], batch['input_cams']['c2ws'])
    # Compute DINO features for the input views (truthiness check also guards dino_layers=None).
    if dino_model is not None and dino_layers:
        batch['input_cams']['dino_features'] = compute_dino_and_store_features(dino_model, batch['input_cams']['imgs'], batch['input_cams']['valid_masks'], dino_layers)
    return batch
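
# Expected batch layout (inferred from the functions above; other keys may exist):
#   batch = {
#       'input_cams': {'imgs': (B, 1, H, W, 3), 'depths': (B, 1, H, W),
#                      'Ks': (B, 1, 3, 3), 'c2ws': (B, 1, 4, 4),
#                      'valid_masks': (B, 1, H, W)},
#       'new_cams':   {'depths': (B, n_views, H, W), 'Ks': (B, n_views, 3, 3),
#                      'c2ws': (B, n_views, 4, 4)},
#   }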

def normalize_batch(batch, normalize_mode):
    """Rescale depths, pointmaps, and camera translations; returns the batch and per-element scale factors."""
    scale_factors = []
    if normalize_mode == 'None':  # the string 'None' disables normalization
        pass
    elif normalize_mode == 'median':
        # Scale each batch element so the median valid input depth becomes 1.
        B = batch['input_cams']['valid_masks'].shape[0]
        for b in range(B):
            input_mask = batch['input_cams']['valid_masks'][b]
            depth_median = batch['input_cams']['depths'][b][input_mask].median()
            scale_factor = 1.0 / depth_median
            scale_factors.append(scale_factor)
            batch['input_cams']['depths'][b] = scale_factor * batch['input_cams']['depths'][b]
            batch['input_cams']['pointmaps'][b] = scale_factor * batch['input_cams']['pointmaps'][b]
            # Only the first input view's translation is scaled (a single input view is assumed).
            batch['input_cams']['c2ws'][b][0, :3, -1] = scale_factor * batch['input_cams']['c2ws'][b][0, :3, -1]
            batch['new_cams']['depths'][b] = scale_factor * batch['new_cams']['depths'][b]
            batch['new_cams']['pointmaps'][b] = scale_factor * batch['new_cams']['pointmaps'][b]
            batch['new_cams']['c2ws'][b][:, :3, -1] = scale_factor * batch['new_cams']['c2ws'][b][:, :3, -1]
    else:
        raise ValueError(f"Unknown normalize_mode: {normalize_mode}")
    return batch, scale_factors

def denormalize_batch(batch, pred, gt, scale_factors):
    """Undo normalize_batch scaling on batch, predictions, and ground truth, then recompute pointmaps."""
    B = len(scale_factors)
    for b in range(B):
        new_scale_factor = 1.0 / scale_factors[b]
        batch['input_cams']['depths'][b] = new_scale_factor * batch['input_cams']['depths'][b]
        batch['input_cams']['pointmaps'][b] = new_scale_factor * batch['input_cams']['pointmaps'][b]
        batch['input_cams']['c2ws'][b][:, :3, -1] = new_scale_factor * batch['input_cams']['c2ws'][b][:, :3, -1]
        batch['new_cams']['depths'][b] = new_scale_factor * batch['new_cams']['depths'][b]
        batch['new_cams']['pointmaps'][b] = new_scale_factor * batch['new_cams']['pointmaps'][b]
        batch['new_cams']['c2ws'][b][:, :3, -1] = new_scale_factor * batch['new_cams']['c2ws'][b][:, :3, -1]
        pred['depths'][b] = new_scale_factor * pred['depths'][b]
        gt['c2ws'][b][:, :3, -1] = new_scale_factor * gt['c2ws'][b][:, :3, -1]
        gt['depths'][b] = new_scale_factor * gt['depths'][b]
        # Recompute pointmaps from the denormalized depths and poses.
        gt['pointmaps'][b] = compute_pointmaps(gt['depths'][b].unsqueeze(1), gt['Ks'][b].unsqueeze(1), gt['c2ws'][b].unsqueeze(1)).squeeze(1)
        pred['pointmaps'][b] = compute_pointmaps(pred['depths'][b].unsqueeze(1), gt['Ks'][b].unsqueeze(1), gt['c2ws'][b].unsqueeze(1)).squeeze(1)
    return batch, pred, gt
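
if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original pipeline): build a
    # random batch with the layout described above and run the non-DINO path.
    B, H, W = 2, 28, 28

    def make_cams(n_views):
        return {
            'depths': torch.randint(1, 65536, (B, n_views, H, W)).float(),
            'Ks': torch.eye(3).expand(B, n_views, 3, 3).clone(),
            'c2ws': torch.eye(4).expand(B, n_views, 4, 4).clone(),
            'valid_masks': torch.ones(B, n_views, H, W, dtype=torch.bool),
        }

    batch = {'input_cams': make_cams(1), 'new_cams': make_cams(3)}
    batch = prepare_fast_batch(batch)
    batch, scale_factors = normalize_batch(batch, 'median')
    print(batch['new_cams']['pointmaps'].shape)  # torch.Size([2, 3, 28, 28, 3])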