File size: 2,934 Bytes
70d1188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
import numpy as np

def to_tensor(x, dtype=torch.float64):
    """Convert a torch tensor or NumPy array into a tensor of ``dtype``.

    Args:
        x: A ``torch.Tensor`` or ``np.ndarray``.
        dtype: Target torch dtype (defaults to ``torch.float64``).

    Returns:
        ``x`` converted with ``Tensor.to(dtype)``. NumPy input is copied
        before being wrapped, so the result does not share memory with ``x``.

    Raises:
        ValueError: If ``x`` is neither a tensor nor a NumPy array.
    """
    if isinstance(x, np.ndarray):
        # Copy first: ``from_numpy`` would otherwise alias the array's buffer.
        x = torch.from_numpy(x.copy())
    elif not isinstance(x, torch.Tensor):
        raise ValueError(f"Unsupported type: {type(x)}")
    return x.to(dtype)

def to_numpy(x):
    """Convert a torch tensor or NumPy array into a NumPy array.

    Args:
        x: A ``torch.Tensor`` or ``np.ndarray``.

    Returns:
        ``x`` itself when it is already a NumPy array; otherwise the tensor
        detached from autograd, moved to the CPU and exposed as an array.

    Raises:
        ValueError: If ``x`` is neither a tensor nor a NumPy array.
    """
    if isinstance(x, np.ndarray):
        return x
    if not isinstance(x, torch.Tensor):
        raise ValueError(f"Unsupported type: {type(x)}")
    detached = x.detach()
    return detached.cpu().numpy()
    
def invalid_to_nans(arr, valid_mask, ndim=999):
    """Replace invalid entries of ``arr`` with NaN and optionally flatten it.

    Args:
        arr: Input tensor.
        valid_mask: Mask used to index ``arr``; positions where it is false
            are set to NaN. ``None`` skips masking.
        ndim: Maximum number of dimensions of the result. When ``arr`` has
            more, the surplus dimensions are merged into the second-to-last
            one; the last dimension is left untouched.

    Returns:
        A masked clone of ``arr`` (or ``arr`` itself when ``valid_mask`` is
        ``None``), flattened down to ``ndim`` dimensions if needed.
    """
    result = arr
    if valid_mask is not None:
        # Work on a clone so the caller's tensor is never modified.
        result = arr.clone()
        invalid = ~valid_mask
        result[invalid] = float('nan')

    surplus = result.ndim - ndim
    if surplus > 0:
        result = result.flatten(start_dim=-2 - surplus, end_dim=-2)
    return result

def invalid_to_zeros(arr, valid_mask, ndim=999):
    """Zero out invalid entries of ``arr`` and report a per-item count.

    Args:
        arr: Input tensor.
        valid_mask: Mask used to index ``arr``; positions where it is false
            are set to 0. ``None`` skips masking.
        ndim: Maximum number of dimensions of the result. When ``arr`` has
            more, the surplus dimensions are merged into the second-to-last
            one; the last dimension is left untouched.

    Returns:
        A tuple ``(arr, nnz)``. With a mask, ``arr`` is a masked clone and
        ``nnz`` is a tensor holding the number of true mask entries for each
        index of the mask's first dimension. Without a mask, ``arr`` is the
        input itself and ``nnz`` is the integer ``arr.numel() // len(arr)``
        (``0`` when ``arr`` is empty along its first dimension); note that
        this count includes every trailing dimension of ``arr``.
    """
    if valid_mask is None:
        # No mask: count every element belonging to one first-dimension entry.
        n_items = len(arr)
        nnz = arr.numel() // n_items if n_items else 0
        result = arr
    else:
        # Work on a clone so the caller's tensor is never modified.
        result = arr.clone()
        result[~valid_mask] = 0
        per_item_mask = valid_mask.view(len(valid_mask), -1)
        nnz = per_item_mask.sum(1)

    surplus = result.ndim - ndim
    if surplus > 0:
        result = result.flatten(start_dim=-2 - surplus, end_dim=-2)
    return result, nnz

def scenes_to_batch(scenes, repeat=None):
    """Merge the second dimension of scene tensors into the first one.

    Two layouts are handled:

    * A nested dictionary containing ``'new_cams'`` (and ``'input_cams'``):
      both sub-dictionaries are flattened recursively, and the input cameras
      are repeated ``n_cams`` times, ``n_cams`` being the value returned for
      ``'new_cams'``.
    * A flat dictionary of tensors: every tensor with more than three
      dimensions is reshaped from ``(d0, d1, ...)`` to ``(d0 * d1, ...)``,
      after being repeated ``repeat`` times along dimension 1 when ``repeat``
      is given. ``'dino_features'`` with three or fewer dimensions is tiled
      ``repeat`` times along dimension 0. Every other entry is passed through
      unchanged.

    Args:
        scenes: Nested or flat dictionary as described above.
        repeat: Optional repetition count; ``None`` disables repetition.

    Returns:
        A tuple ``(batch, n_cams)``. ``n_cams`` is the size of dimension 1
        (before any repetition) of the last visited tensor with more than
        three dimensions, or ``None`` if there is no such tensor.
    """
    if 'new_cams' in scenes:
        # The camera count comes from the flattened 'new_cams' entries
        # themselves, so no particular key (such as 'depths') is required.
        new_cams, n_cams = scenes_to_batch(scenes['new_cams'])
        input_cams, _ = scenes_to_batch(scenes['input_cams'], repeat=n_cams)
        return {'new_cams': new_cams, 'input_cams': input_cams}, n_cams

    batch = {}
    n_cams = None
    for key, value in scenes.items():
        shape = value.shape
        if len(shape) > 3:
            n_cams = shape[1]
            if repeat is not None:
                # Repeat along dimension 1 only, so the repeated copies are
                # folded into the leading dimension by the reshape below.
                reps = [1] * len(shape)
                reps[1] = repeat
                value = value.repeat(*reps)
            batch[key] = value.reshape(-1, *shape[2:])
        elif key == 'dino_features' and repeat is not None:
            # NOTE(review): this tiles the whole tensor along dimension 0
            # (a, b, a, b, ...), whereas the branch above yields all copies
            # of one leading-dimension entry before the next one. The two
            # orderings differ when the leading dimension is larger than 1 --
            # confirm this is intended.
            batch[key] = value.repeat(repeat, *([1] * (len(shape) - 1)))
        else:
            # Includes 'dino_features' when ``repeat`` is None: calling
            # ``Tensor.repeat(None, ...)`` would raise a TypeError.
            batch[key] = value
    return batch, n_cams

def dict_to_scenes(input_dict, n_cams):
    """Split the first dimension of every tensor into ``(-1, n_cams)``.

    Args:
        input_dict: Possibly nested dictionary whose leaves support
            ``reshape`` and ``shape``.
        n_cams: Size of the new second dimension.

    Returns:
        A new dictionary with the same keys and nesting, in which each leaf
        of shape ``(d0, ...)`` is reshaped to ``(d0 // n_cams, n_cams, ...)``.
    """
    def _unflatten(value):
        # Nested dictionaries are processed recursively; leaves are reshaped.
        if isinstance(value, dict):
            return dict_to_scenes(value, n_cams)
        trailing = value.shape[1:]
        return value.reshape(-1, n_cams, *trailing)

    return {key: _unflatten(value) for key, value in input_dict.items()}

def batch_to_scenes(pred, gt, batch, n_cams):
    """Apply ``dict_to_scenes`` to the batch, predictions and ground truth.

    Args:
        pred: Dictionary of predictions.
        gt: Dictionary of ground-truth values.
        batch: Dictionary of batch inputs.
        n_cams: Value forwarded to ``dict_to_scenes`` for each dictionary.

    Returns:
        A tuple ``(pred, gt, batch)`` of the converted dictionaries.
    """
    # Conversions run in the order batch, pred, gt.
    scene_batch, scene_pred, scene_gt = (
        dict_to_scenes(flat, n_cams) for flat in (batch, pred, gt)
    )
    return scene_pred, scene_gt, scene_batch