import sys
import os

# Add the current working directory to the Python path so the xps package
# resolves when this script is run from the repository root.
current_dir = os.getcwd()
sys.path.append(current_dir)
from xps.util import *  # provides gen_key, among other helpers
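
# Experiment launcher: everything below assembles a single `torchrun ... main.py`
# command line as a string and prints it; the output is presumably piped into a
# shell (or consumed by the surrounding xps tooling) to start the actual run.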

root_log_dir = "logs"
n_views = 2
dataset_size = -1  # negative size presumably means "use the full dataset"

# Image shapes, presumably (height, width).
imshape_input = (480, 640)
imshape_output = (480, 640)
render_size = (480, 640)

preload_train = False
data_dirs = ["/home/jovyan/shared/bduister/data/processed/", "/home/jovyan/shared/bduister/data-2/processed/"]
dino_features = [4, 11, 17, 23]  # DINO layers to take features from (forwarded as dino_layers below)
datasets = ['fp_gso', 'octmae']
prefetch_dino = False
normalize_mode = 'median'
# Optional checkpoint to warm-start from; None trains from scratch.
# start_from = "checkpoints/gso_conf.pth"
start_from = None

noise_std = 0.005
view_select_mode = "new_zoom"
rendered_views_mode = "always"
# Loaders are passed to main.py as constructor-call strings, presumably
# evaluated on the receiving side. dataset_just_load mirrors dataset_test and
# is only used by the --just_load path below.
dataset_train = f"GenericLoader(size={dataset_size},seed=747,dir={repr(data_dirs)},split='train',datasets={datasets},mode='fast',prefetch_dino={prefetch_dino}," \
    + f"dino_features={dino_features},view_select_mode='{view_select_mode}',noise_std={noise_std},rendered_views_mode='{rendered_views_mode}')"
dataset_test = f"GenericLoader(size=1000,seed=787,dir={repr(data_dirs)},split='test',datasets={datasets},mode='fast',prefetch_dino={prefetch_dino}," \
    + f"dino_features={dino_features},view_select_mode='{view_select_mode}',noise_std={noise_std},rendered_views_mode='{rendered_views_mode}')"
dataset_just_load = dataset_test

augmentor = "Augmentor()"

patch_size = 16
save_every = 1

vit="base"
if vit == "debug":
    enc_dim = 128
    dec_dim = 128
    n_heads = 4
    enc_depth = 4
    dec_depth = 4
    head_n_layers = 1
    head_dim = 128
    lr = 3e-4  # unused in the final command; only blr is forwarded to main.py
    batch_size = 20
    blr = 1.5e-4
elif vit == "debug_2":
    enc_dim = 512
    dec_dim = 512
    n_heads = 4
    enc_depth = 4
    dec_depth = 10
    head_n_layers = 1
    head_dim = 128
    blr = 1.5e-4
    batch_size = 18
elif vit == "small":
    enc_dim = 384
    dec_dim = 384
    n_heads = 6
    enc_depth = 12
    dec_depth = 12
    head_n_layers = 1
    head_dim = 128
    batch_size = 6
    blr = 1.5e-4
elif vit == "base":
    enc_dim = 768
    dec_dim = 768
    n_heads = 12
    enc_depth = 4
    dec_depth = 12
    head_n_layers = 1
    head_dim = 128
    batch_size = 10
    blr = 1.5e-4
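else:
    # Fail fast on a misspelled preset name; otherwise it would only surface
    # later as a NameError on enc_dim.
    raise ValueError(f"unknown vit preset: {vit!r}")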

lambda_classifier = 0.1
# Single-value sweep: written as a loop so further ablation settings (e.g.
# skip_conf_points=True) can be slotted in easily.
for skip_conf_points in [False]:
    skip_conf_mask = True
    # Same constructor-string convention as the loaders. The criterion combines
    # a pointmap regression loss (L21) and a BCE mask loss, each wrapped in
    # ConfLoss (presumably confidence-weighted), mixed by lambda_classifier.
    model = f"RayQuery(ray_enc=RayEncoder(dim={enc_dim},num_heads={n_heads},depth={enc_depth},img_size={render_size},patch_size={patch_size})," + \
            f"pointmap_enc=PointmapEncoder(dim={enc_dim},num_heads={n_heads},depth={enc_depth},img_size={render_size},patch_size={patch_size})," + \
            f"dino_layers={dino_features}," + \
            f"pts_head_type='dpt_depth'," + \
            f"classifier_head_type='dpt_mask'," + \
            f"decoder_dim={dec_dim},decoder_depth={dec_depth},decoder_num_heads={n_heads},imshape={render_size}," + \
            f"criterion=DepthCompletion(ConfLoss(L21,skip_conf={skip_conf_points}),ConfLoss(ClassifierLoss(BCELoss()),skip_conf={skip_conf_mask}),lambda_classifier={lambda_classifier}),return_all_blocks=True)"

    key = gen_key(f"conf_points_{not skip_conf_points}")
    logdir = os.path.join(root_log_dir, key)
    # resume points at the run's own logdir, so relaunching the same config
    # presumably continues from its latest checkpoint.
    resume = logdir
    wandb_run_name = key
    os.makedirs(logdir, exist_ok=True)

    n_epochs = 20
    eval_every = 1
    max_norm = -1  # negative value presumably disables gradient clipping
    OMP_NUM_THREADS = 16
    warmup_epochs = 1

    # One process per visible GPU; the random master port in [29000, 29500)
    # avoids collisions between concurrent runs on the same host.
    executable = f"OMP_NUM_THREADS={OMP_NUM_THREADS} torchrun --nnodes 1 --nproc_per_node $(python -c 'import torch; print(torch.cuda.device_count())') --master_port $((RANDOM%500+29000)) main.py"
    # executable = f"python main.py"  # single-process fallback
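    # Note on the f-string "=" specifier used below (Python 3.8+):
    # f"--{batch_size=}" expands to "--batch_size=10", and strings come out
    # repr-quoted, e.g. f"--{normalize_mode=}" -> "--normalize_mode='median'".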
    if '--just_load' in sys.argv:
        # --just_load presumably just exercises the data loaders: small batch,
        # no training arguments.
        batch_size = 5
        command = f"{executable} --{dataset_train=} --{dataset_test=} --{dataset_just_load=} --{logdir=} --{resume=} --{model=} --{batch_size=} --{normalize_mode=} --{augmentor=}"
    else:
        command = f"{executable} --{dataset_train=} --{dataset_test=} --{logdir=} --{n_epochs=} --{resume=} --{normalize_mode=} --{augmentor=} --{warmup_epochs=}"
        command += f" --{model=} --{eval_every=} --{batch_size=} --{save_every=} --{max_norm=}"
        command += f" --{blr=}"
        if start_from is not None:
            command += f" --{start_from=}"
        if '--no_wandb' not in sys.argv:
            command += f" --wandb_project=3dcomplete --{wandb_run_name=}"

    # Forward any extra CLI arguments verbatim, except --no_wandb, which is
    # consumed by this script alone.
    for arg in sys.argv[1:]:
        if '--no_wandb' not in arg:
            command += f" {arg}"
    print(command)
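
# Usage sketch (file name hypothetical; this script only prints the command):
#   python xps/this_config.py | bash               # launch training
#   python xps/this_config.py --just_load | bash   # loader-only run
#   python xps/this_config.py --no_wandb | bash    # train without wandb logging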