File size: 6,119 Bytes
70d1188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import torch
import torchvision.transforms as tvf

# ViT patch size of the DINO backbone; images are center-cropped to multiples
# of this before feature extraction (see compute_dino_and_store_features).
dino_patch_size = 14

def batch_to_device(batch, device='cuda'):
    """Recursively move every tensor in a (possibly nested) dict onto `device`.

    Mutates `batch` in place and also returns it. Non-tensor, non-dict values
    are left untouched.
    """
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.to(device)
        elif isinstance(value, dict):
            batch[key] = batch_to_device(value, device)
    return batch


def compute_pointmap(depth: torch.Tensor, intrinsics: torch.Tensor, cam2world: torch.Tensor = None) -> torch.Tensor:
    """Back-project a depth map into a 3D pointmap.

    Args:
        depth: (H, W) depth values.
        intrinsics: (3, 3) pinhole intrinsics matrix K.
        cam2world: optional (4, 4) camera-to-world pose. When given, points are
            returned in world coordinates; otherwise in camera coordinates.

    Returns:
        (H, W, 3) tensor of 3D points, one per pixel.
    """
    fx, fy = intrinsics[0, 0], intrinsics[1, 1]
    cx, cy = intrinsics[0, 2], intrinsics[1, 2]
    h, w = depth.shape

    # Build the pixel grid directly on the depth's device instead of creating
    # it on CPU and copying it over afterwards.
    i, j = torch.meshgrid(
        torch.arange(w, device=depth.device),
        torch.arange(h, device=depth.device),
        indexing='xy',
    )

    # Standard pinhole back-projection: x = (u - cx) * z / fx, etc.
    x_cam = (i - cx) * depth / fx
    y_cam = (j - cy) * depth / fy

    points = torch.stack([x_cam, y_cam, depth], dim=-1)  # torch-native `dim=` (was numpy-style `axis=`)

    if cam2world is not None:
        # Rigid transform to world frame: R @ p + t for every point.
        points = torch.matmul(cam2world[:3, :3], points.reshape(-1, 3).T).T + cam2world[:3, 3]
    points = points.reshape(h, w, 3)

    return points

def compute_pointmaps(depths: torch.Tensor, intrinsics: torch.Tensor, cam2worlds: torch.Tensor) -> torch.Tensor:
    """Back-project a batch of multi-view depth maps to pointmaps.

    Args:
        depths: (B, V, H, W) depth maps — B samples, V views each.
        intrinsics: (B, V, 3, 3) per-view intrinsics.
        cam2worlds: (B, V, 4, 4) per-view camera-to-world poses.

    Returns:
        (B, V, H, W, 3) pointmaps in world coordinates.
    """
    pointmaps = []
    for depth, K, c2w in zip(depths, intrinsics, cam2worlds):
        # Iterate views within a sample; compute_pointmap handles one (H, W) map.
        for i in range(depth.shape[0]):
            pointmaps.append(compute_pointmap(depth[i], K[i], c2w[i]))
    # Restack the flat list back into the (B, V, ...) layout. (Removed the
    # unused `depth_shape` local from the original.)
    return torch.stack(pointmaps).reshape(depths.shape + (3,))

def depth_to_metric(depth, depth_max=10.0):
    """Convert uint16-encoded depth values to metric depth.

    Raw depth is stored as 16-bit values in [0, 65535] that linearly encode
    the range [0, depth_max].

    Args:
        depth: (H, W) (or any shape) raw depth values.
        depth_max: depth corresponding to the maximum raw value 65535
            (generalized from the original hard-coded 10.0; default unchanged).

    Returns:
        Depth scaled to metric units, same shape as the input.
    """
    return depth_max * (depth / 65535.0)

def make_rgb_transform() -> tvf.Compose:
    """Build the RGB preprocessing pipeline for the DINO backbone.

    Expects a float tensor already in CHW layout with values in [0, 255];
    applies ImageNet mean/std normalization (in 0-255 scale). Removed the
    commented-out ToTensor/scaling steps from the original — the caller
    (compute_dino_and_store_features) performs conversion itself.
    """
    return tvf.Compose([
        tvf.Normalize(
            mean=(123.675, 116.28, 103.53),
            std=(58.395, 57.12, 57.375),
        ),
    ])

# Module-level singleton so the transform is constructed only once.
rgb_transform = make_rgb_transform()

def compute_dino_and_store_features(dino_model : torch.nn.Module, rgb: torch.Tensor, mask: torch.Tensor,dino_layers: list[int] = None) -> torch.Tensor:
    """Computes the DINO features given an RGB image.

    Args:
        dino_model: DINO backbone exposing `get_intermediate_layers`.
        rgb: image batch; assumed (B, 1, H, W, 3) given the squeeze/permute
            below — TODO confirm against the caller.
        mask: validity mask; assumed (B, 1, H, W). Invalid pixels are zeroed
            before feature extraction.
        dino_layers: indices of intermediate layers to extract and concatenate.

    Returns:
        Per-patch features with the requested layers concatenated along the
        last dimension.
    """
    rgb = rgb.squeeze(1)   # drop singleton view dimension
    mask = mask.squeeze(1)
    rgb = rgb.permute(0,3,1,2)  # HWC -> CHW
    mask = mask.unsqueeze(1).repeat(1,3,1,1)  # replicate mask across the 3 channels
    rgb = rgb * mask  # zero out masked-out pixels

    rgb = rgb.float()
    H, W = rgb.shape[-2:]
    # Crop to the largest size divisible by the ViT patch size.
    goal_H, goal_W = H//dino_patch_size*dino_patch_size, W//dino_patch_size*dino_patch_size
    resize_transform = tvf.CenterCrop([goal_H, goal_W])
    with torch.no_grad():
        rgb = resize_transform(rgb)
        rgb = rgb_transform(rgb)  # ImageNet mean/std normalization (0-255 scale)
        all_feat = dino_model.get_intermediate_layers(rgb, dino_layers)
        dino_feat = torch.cat(all_feat, dim=-1)
    return dino_feat


def prepare_fast_batch(batch, dino_model=None, dino_layers=None):
    """Prepare a raw batch: metric depths, pointmaps, optional DINO features.

    Args:
        batch: nested dict with 'new_cams' and 'input_cams' sub-dicts, each
            holding 'depths', 'Ks', 'c2ws' (input_cams additionally 'imgs'
            and 'valid_masks').
        dino_model: optional DINO backbone; features are computed only when
            both this and a non-empty `dino_layers` are given.
        dino_layers: indices of DINO layers to extract.

    Returns:
        The same batch dict, mutated in place.
    """
    # Decode stored raw depths into metric depth.
    batch['new_cams']['depths'] = depth_to_metric(batch['new_cams']['depths'])
    batch['input_cams']['depths'] = depth_to_metric(batch['input_cams']['depths'])

    # Back-project depths to world-space pointmaps.
    batch['new_cams']['pointmaps'] = compute_pointmaps(batch['new_cams']['depths'], batch['new_cams']['Ks'], batch['new_cams']['c2ws'])
    batch['input_cams']['pointmaps'] = compute_pointmaps(batch['input_cams']['depths'], batch['input_cams']['Ks'], batch['input_cams']['c2ws'])

    # Compute DINO features for the input views. Bug fix: the original tested
    # `len(dino_layers) > 0`, which raises TypeError when dino_model is given
    # but dino_layers is left at its default None; a truthiness check covers
    # both None and the empty list.
    if dino_model is not None and dino_layers:
        batch['input_cams']['dino_features'] = compute_dino_and_store_features(dino_model, batch['input_cams']['imgs'], batch['input_cams']['valid_masks'], dino_layers)

    return batch


def normalize_batch(batch,normalize_mode):
    """Scale-normalize depths, pointmaps and camera translations in place.

    Args:
        batch: nested dict with 'input_cams' and 'new_cams' sub-dicts, each
            holding 'depths', 'pointmaps', 'c2ws' (input_cams additionally
            'valid_masks').
        normalize_mode: 'None' for a no-op, or 'median' to rescale each batch
            sample so its median valid input depth becomes 1.

    Returns:
        (batch, scale_factors): the mutated batch and the per-sample scale
        factors applied (empty list when normalize_mode is 'None').
    """
    scale_factors = []
    if normalize_mode == 'None':
        pass
    elif normalize_mode == 'median':
        B = batch['input_cams']['valid_masks'].shape[0]
        for b in range(B):
            input_mask = batch['input_cams']['valid_masks'][b]
            # Median depth over valid pixels defines the per-sample scale.
            depth_median = batch['input_cams']['depths'][b][input_mask].median()
            scale_factor = 1.0 / depth_median
            scale_factors.append(scale_factor)
            batch['input_cams']['depths'][b] = scale_factor * batch['input_cams']['depths'][b]
            batch['input_cams']['pointmaps'][b] = scale_factor * batch['input_cams']['pointmaps'][b]
            # NOTE(review): only view 0's translation is scaled here, whereas
            # new_cams below (and denormalize_batch) scale all views with
            # [:, :3, -1]. Confirm whether skipping input views > 0 is
            # intentional (e.g. only the reference pose is used downstream).
            batch['input_cams']['c2ws'][b][0,:3,-1] = scale_factor * batch['input_cams']['c2ws'][b][0,:3,-1]

            batch['new_cams']['depths'][b] = scale_factor * batch['new_cams']['depths'][b]
            batch['new_cams']['pointmaps'][b] = scale_factor * batch['new_cams']['pointmaps'][b]
            batch['new_cams']['c2ws'][b][:,:3,-1] = scale_factor * batch['new_cams']['c2ws'][b][:,:3,-1]

    return batch, scale_factors

def denormalize_batch(batch,pred,gt,scale_factors):
    """Undo normalize_batch scaling on the batch, predictions and ground truth.

    Args:
        batch: the normalized batch (mutated in place).
        pred: dict of predictions with at least 'depths' and 'pointmaps'.
        gt: dict of ground truth with 'depths', 'pointmaps', 'Ks', 'c2ws'.
        scale_factors: per-sample factors returned by normalize_batch.

    Returns:
        (batch, pred, gt), all denormalized in place. Pointmaps for pred and
        gt are recomputed from the rescaled depths and poses rather than
        rescaled directly.
    """
    B = len(scale_factors)
    n_new_cams = batch['new_cams']['c2ws'].shape[1]  # NOTE(review): unused — candidate for removal
    for b in range(B):
        # Inverse of the normalization scale.
        new_scale_factor = 1.0 / scale_factors[b]
        batch['input_cams']['depths'][b] = new_scale_factor * batch['input_cams']['depths'][b]
        batch['input_cams']['pointmaps'][b] = new_scale_factor * batch['input_cams']['pointmaps'][b]
        batch['input_cams']['c2ws'][b][:,:3,-1] = new_scale_factor * batch['input_cams']['c2ws'][b][:,:3,-1]
        batch['new_cams']['depths'][b] = new_scale_factor * batch['new_cams']['depths'][b]
        batch['new_cams']['pointmaps'][b] = new_scale_factor * batch['new_cams']['pointmaps'][b]
        batch['new_cams']['c2ws'][b][:,:3,-1] = new_scale_factor * batch['new_cams']['c2ws'][b][:,:3,-1]

        pred['depths'][b] = new_scale_factor * pred['depths'][b]

        gt['c2ws'][b][:,:3,-1] = new_scale_factor * gt['c2ws'][b][:,:3,-1]
        gt['depths'][b] = new_scale_factor * gt['depths'][b]

        # Recompute pointmaps from the denormalized depths/poses; the extra
        # unsqueeze/squeeze adapts a single sample to compute_pointmaps'
        # batched (B, V, ...) interface. Note pred pointmaps use gt poses.
        gt['pointmaps'][b] = compute_pointmaps(gt['depths'][b].unsqueeze(1),gt['Ks'][b].unsqueeze(1),gt['c2ws'][b].unsqueeze(1)).squeeze(1)
        pred['pointmaps'][b] = compute_pointmaps(pred['depths'][b].unsqueeze(1),gt['Ks'][b].unsqueeze(1),gt['c2ws'][b].unsqueeze(1)).squeeze(1)
    return batch, pred, gt