|
import sys |
|
import socket |
|
import os |
|
|
|
current_dir = os.getcwd() |
|
sys.path.append(current_dir) |
|
from xps.util import * |
|
|
|
root_log_dir = "logs" |
|
n_views = 2 |
|
dataset_size = -1 |
|
|
|
imshape_input = (480,640) |
|
imshape_output = (480,640) |
|
render_size = (480,640) |
|
|
|
preload_train = False |
|
data_dirs = ["/home/jovyan/shared/bduister/data/processed/","/home/jovyan/shared/bduister/data-2/processed/"] |
|
dino_features = [4,11,17,23] |
|
datasets = ['fp_gso','octmae'] |
|
prefetch_dino = False |
|
normalize_mode = 'median' |
|
|
|
start_from = None |
|
|
|
noise_std = 0.005 |
|
view_select_mode = "new_zoom" |
|
rendered_views_mode = "always" |
|
dataset_train = f"GenericLoader(size={dataset_size},seed=747,dir={repr(data_dirs)},split='train',datasets={datasets},mode='fast',prefetch_dino={prefetch_dino}," \ |
|
+f"dino_features={dino_features},view_select_mode='{view_select_mode}',noise_std={noise_std},rendered_views_mode='{rendered_views_mode}')" |
|
dataset_test = f"GenericLoader(size=1000,seed=787,dir={repr(data_dirs)},split='test',datasets={datasets},mode='fast',prefetch_dino={prefetch_dino}," \ |
|
+f"dino_features={dino_features},view_select_mode='{view_select_mode}',noise_std={noise_std},rendered_views_mode='{rendered_views_mode}')" |
|
dataset_just_load = f"GenericLoader(size=1000,seed=787,dir={repr(data_dirs)},split='test',datasets={datasets},mode='fast',prefetch_dino={prefetch_dino}," \ |
|
+f"dino_features={dino_features},view_select_mode='{view_select_mode}',noise_std={noise_std},rendered_views_mode='{rendered_views_mode}')" |
|
|
|
augmentor = "Augmentor()" |
|
|
|
patch_size = 16 |
|
save_every = 1 |
|
|
|
vit="base" |
|
if vit == "debug": |
|
enc_dim = 128 |
|
dec_dim = 128 |
|
n_heads = 4 |
|
enc_depth = 4 |
|
dec_depth = 4 |
|
head_n_layers = 1 |
|
head_dim = 128 |
|
lr = 3e-4 |
|
batch_size = 20 |
|
blr = 1.5e-4 |
|
elif vit == "debug_2": |
|
enc_dim = 512 |
|
dec_dim = 512 |
|
n_heads = 4 |
|
enc_depth = 4 |
|
dec_depth = 10 |
|
head_n_layers = 1 |
|
head_dim = 128 |
|
blr = 1.5e-4 |
|
batch_size = 18 |
|
elif vit == "small": |
|
enc_dim = 384 |
|
dec_dim = 384 |
|
n_heads = 6 |
|
enc_depth = 12 |
|
dec_depth = 12 |
|
head_n_layers = 1 |
|
head_dim = 128 |
|
batch_size = 6 |
|
blr = 1.5e-4 |
|
elif vit == "base": |
|
enc_dim = 768 |
|
dec_dim = 768 |
|
n_heads = 12 |
|
enc_depth = 4 |
|
dec_depth = 12 |
|
head_n_layers = 1 |
|
head_dim = 128 |
|
batch_size = 10 |
|
blr = 1.5e-4 |
|
|
|
lambda_classifier = 0.1 |
|
for skip_conf_points in [False]: |
|
skip_conf_mask = True |
|
model = f"RayQuery(ray_enc=RayEncoder(dim={enc_dim},num_heads={n_heads},depth={enc_depth},img_size={render_size},patch_size={patch_size})," + \ |
|
f"pointmap_enc=PointmapEncoder(dim={enc_dim},num_heads={n_heads},depth={enc_depth},img_size={render_size},patch_size={patch_size})," + \ |
|
f"dino_layers={dino_features}," + \ |
|
f"pts_head_type='dpt_depth'," + \ |
|
f"classifier_head_type='dpt_mask'," + \ |
|
f"decoder_dim={dec_dim},decoder_depth={dec_depth},decoder_num_heads={n_heads},imshape={render_size}," + \ |
|
f"criterion=DepthCompletion(ConfLoss(L21,skip_conf={skip_conf_points}),ConfLoss(ClassifierLoss(BCELoss()),skip_conf={skip_conf_mask}),lambda_classifier={lambda_classifier}),return_all_blocks=True)" |
|
|
|
key = f"conf_points_{skip_conf_points==False}" |
|
key = gen_key(key) |
|
logdir = os.path.join(root_log_dir,key) |
|
resume=logdir |
|
wandb_run_name=key |
|
os.makedirs(logdir,exist_ok=True) |
|
|
|
n_epochs = 20 |
|
eval_every = 1 |
|
max_norm = -1 |
|
OMP_NUM_THREADS=16 |
|
warmup_epochs = 1 |
|
|
|
executable = f"OMP_NUM_THREADS={OMP_NUM_THREADS} torchrun --nnodes 1 --nproc_per_node $(python -c 'import torch; print(torch.cuda.device_count())') --master_port $((RANDOM%500+29000)) main.py" |
|
|
|
if '--just_load' in sys.argv: |
|
batch_size = 5 |
|
command = f"{executable} --{dataset_train=} --{dataset_test=} --{dataset_just_load=} --{logdir=} --{resume=} --{model=} --{batch_size=} --{normalize_mode=} --{augmentor=}" |
|
else: |
|
command = f"{executable} --{dataset_train=} --{dataset_test=} --{logdir=} --{n_epochs=} --{resume=} --{normalize_mode=} --{augmentor=} --{warmup_epochs=}" |
|
command += f" --{model=} --{eval_every=} --{batch_size=} --{save_every=} --{max_norm=}" |
|
command += f" --{blr=}" |
|
if start_from is not None: |
|
command += f" --{start_from=}" |
|
if not '--no_wandb' in sys.argv: |
|
command += f" --wandb_project=3dcomplete " + \ |
|
f"--{wandb_run_name=}" |
|
|
|
if len(sys.argv) > 1: |
|
for arg in sys.argv[1:]: |
|
if not '--no_wandb' in arg: |
|
command += f" {arg}" |
|
print(command) |
|
|