|
# Short alias for the builtin ``breakpoint()`` debugger hook, so a debugger
# stop can be written as ``bb()``.
bb = breakpoint
|
import torch |
|
from torch.utils.data import DataLoader |
|
import wandb |
|
from argparse import ArgumentParser |
|
from datasets.octmae import OctMae |
|
from datasets.foundation_pose import FoundationPose |
|
from datasets.generic_loader import GenericLoader |
|
|
|
from utils.collate import collate |
|
from models.rayquery import RayQuery |
|
from engine import train_epoch, eval_epoch, eval_model |
|
import torch.nn as nn |
|
from models.rayquery import RayQuery, PointmapEncoder, RayEncoder |
|
from models.losses import * |
|
import utils.misc as misc |
|
import os |
|
from utils.viz import just_load_viz |
|
from utils.fusion import fuse_batch |
|
import socket |
|
import time |
|
from utils.augmentations import * |
|
|
|
def parse_args():
    """Define the training / visualization command line and parse ``sys.argv``.

    ``--dataset_train``, ``--dataset_test``, ``--dataset_just_load``,
    ``--model`` and ``--augmentor`` hold Python expression strings (constructor
    calls); ``main`` passes them to ``eval``. ``--augmentor`` and
    ``--normalize_mode`` default to the *string* ``'None'``, not ``None``.
    ``--rr_addr`` defaults to ``0.0.0.0:<port>``, where the port is taken from
    the ``RERUN_RECORDING`` environment variable (``9876`` if unset).

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    rerun_port = os.getenv("RERUN_RECORDING", "9876")

    # (flag, add_argument keyword arguments) in registration order; the order
    # only affects how ``--help`` lists the options.
    option_table = (
        # --- data ---
        ("--dataset_train", {"type": str, "default": "TableOfCubes(size=10,n_views=2,seed=747)"}),
        ("--dataset_test", {"type": str, "default": "TableOfCubes(size=10,n_views=2,seed=787)"}),
        ("--dataset_just_load", {"type": str, "default": None}),
        # --- run bookkeeping ---
        ("--logdir", {"type": str, "default": "logs"}),
        ("--batch_size", {"type": int, "default": 5}),
        ("--n_epochs", {"type": int, "default": 100}),
        ("--n_workers", {"type": int, "default": 4}),
        # --- model ---
        ("--model", {"type": str, "default": "RayQuery(ray_enc=RayEncoder(),pointmap_enc=PointmapEncoder(),criterion=RayCompletion(ConfLoss(L21)))"}),
        ("--save_every", {"type": int, "default": 1}),
        ("--resume", {"type": str, "default": None}),
        ("--eval_every", {"type": int, "default": 3}),
        # --- logging / visualization ---
        ("--wandb_project", {"type": str, "default": None}),
        ("--wandb_run_name", {"type": str, "default": "init"}),
        ("--just_load", {"action": "store_true"}),
        ("--device", {"type": str, "default": "cuda"}),
        ("--rr_addr", {"type": str, "default": "0.0.0.0:" + rerun_port}),
        ("--mesh", {"action": "store_true"}),
        # --- optimization ---
        ("--max_norm", {"type": float, "default": -1}),
        ('--lr', {"type": float, "default": None, "metavar": 'LR',
                  "help": 'learning rate (absolute lr)'}),
        ('--blr', {"type": float, "default": 1.5e-4, "metavar": 'LR',
                   "help": 'base learning rate: absolute_lr = base_lr * total_batch_size / 256'}),
        ('--min_lr', {"type": float, "default": 1e-6, "metavar": 'LR',
                      "help": 'lower lr bound for cyclic schedulers that hit 0'}),
        ('--warmup_epochs', {"type": int, "default": 10}),
        ('--weight_decay', {"type": float, "default": 0.01}),
        # --- misc ---
        ('--normalize_mode', {"type": str, "default": 'None'}),
        ('--start_from', {"type": str, "default": None}),
        ('--augmentor', {"type": str, "default": 'None'}),
    )

    cli = ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
|
|
|
def main(args):
    """Train a model, or visualize its predictions when ``--just_load`` is set.

    ``args`` is the namespace returned by ``parse_args``. The dataset, model and
    augmentor options are Python expression strings evaluated with ``eval``, so
    they may name any class imported at the top of this module.

    Modes:
      * ``args.just_load``: build one loader over ``--dataset_just_load``
        (falling back to ``--dataset_train``). For each batch, run
        ``eval_model(..., mode='viz')`` under ``torch.no_grad()``, print the loss
        dict, send the results to ``just_load_viz``, then stop in
        ``breakpoint()``. The loop over the loader repeats indefinitely.
      * otherwise: train for ``args.n_epochs`` epochs starting at ``start_epoch``.
        Evaluate on the test loader every ``args.eval_every`` epochs and save the
        ``latest`` checkpoint every ``args.save_every`` epochs.

    NOTE: ``eval`` executes arbitrary code taken from the command line. Only run
    this script with trusted arguments.
    """
    load_dino = False

    if not args.just_load:
        dataset_train = eval(args.dataset_train)
        dataset_test = eval(args.dataset_test)
        # DINO features are computed here only when the dataset does not
        # already provide (prefetch) them.
        if not dataset_train.prefetch_dino:
            load_dino = True

        rank, world_size, local_rank = misc.setup_distributed()

        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=world_size, rank=rank, shuffle=True
        )
        sampler_test = torch.utils.data.DistributedSampler(
            dataset_test, num_replicas=world_size, rank=rank, shuffle=False
        )

        # The samplers own the ordering, so the loaders themselves use shuffle=False.
        loader_kwargs = dict(
            batch_size=args.batch_size,
            shuffle=False,
            collate_fn=collate,
            num_workers=args.n_workers,
            pin_memory=True,
            drop_last=True,
        )
        if args.n_workers > 0:
            # prefetch_factor is only valid with worker processes. PyTorch 2.x
            # raises ValueError if it is passed together with num_workers=0.
            loader_kwargs["prefetch_factor"] = 2

        train_loader = DataLoader(dataset_train, sampler=sampler_train, **loader_kwargs)
        test_loader = DataLoader(dataset_test, sampler=sampler_test, **loader_kwargs)

        # Number of scenes consumed per epoch across all ranks; used as the
        # wandb step unit for test metrics below.
        n_scenes_epoch = len(train_loader) * args.batch_size * world_size
        print(f"Number of scenes in epoch: {n_scenes_epoch}")
    else:
        if args.dataset_just_load is None:
            dataset = eval(args.dataset_train)
        else:
            dataset = eval(args.dataset_just_load)
        if not dataset.prefetch_dino:
            load_dino = True

        rank, world_size, local_rank = misc.setup_distributed()

        sampler_just_load = torch.utils.data.DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=False
        )
        just_loader = DataLoader(
            dataset,
            sampler=sampler_just_load,
            batch_size=args.batch_size,
            shuffle=False,
            collate_fn=collate,
            pin_memory=True,
            drop_last=True,
        )

    model = eval(args.model).to(args.device)

    # '--augmentor' uses the string 'None' (not None) to mean "no augmentation".
    if args.augmentor != 'None':
        augmentor = eval(args.augmentor)
    else:
        augmentor = None

    if load_dino and len(model.dino_layers) > 0:
        dino_model = torch.hub.load('facebookresearch/dinov2', "dinov2_vitl14_reg")
        dino_model.eval()
        # Keep the feature extractor on the same device as the model it feeds.
        dino_model.to(args.device)
    else:
        dino_model = None

    # find_unused_parameters=True lets DDP tolerate parameters that receive no
    # gradient in a given step.
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank], find_unused_parameters=True
    )
    model_without_ddp = model.module if hasattr(model, 'module') else model

    # Linear LR scaling: absolute lr = base lr * effective batch size / 256.
    eff_batch_size = args.batch_size * misc.get_world_size()
    if args.lr is None:
        args.lr = args.blr * eff_batch_size / 256

    param_groups = misc.add_weight_decay(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))

    os.makedirs(args.logdir, exist_ok=True)
    start_epoch = 0
    print("Running on host %s" % socket.gethostname())

    # --resume continues a run (weights, and optimizer/epoch when present).
    # --start_from only initializes the weights.
    if args.resume and os.path.exists(os.path.join(args.resume, "checkpoint-latest.pth")):
        checkpoint = torch.load(os.path.join(args.resume, "checkpoint-latest.pth"), map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        print("Resume checkpoint %s" % args.resume)
        if 'optimizer' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch'] + 1
            print("With optim & sched!")
        del checkpoint
    elif args.start_from is not None:
        checkpoint = torch.load(args.start_from, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        print("Start from checkpoint %s" % args.start_from)
        del checkpoint

    if args.just_load:
        if len(just_loader) == 0:
            # With drop_last=True, a rank with fewer samples than batch_size
            # gets no batches, and the `while True` loop below would spin forever.
            print("No batches to visualize (dataset smaller than batch_size per rank); exiting.")
            return
        name = args.logdir
        addr = args.rr_addr
        with torch.no_grad():
            while True:
                for data in just_loader:
                    pred, gt, loss_dict, batch = eval_model(
                        model, data, mode='viz', args=args,
                        dino_model=dino_model, augmentor=augmentor,
                    )
                    gt = {k: v.float() for k, v in gt.items()}
                    pred = {k: v.float() for k, v in pred.items()}

                    print(f"{'Key':<10} {'Value':<10}")
                    print("-"*20)
                    for key, value in loss_dict.items():
                        print(f"{key:<10}: {value:.4f}")
                    print("-"*20)

                    if args.mesh:
                        fused_meshes = fuse_batch(pred, gt, data, voxel_size=0.002)
                    else:
                        fused_meshes = None
                    just_load_viz(pred, gt, batch, addr=addr, name=name, fused_meshes=fused_meshes)
                    # Stop in the debugger after every batch (presumably so the
                    # visualization can be inspected before moving on).
                    breakpoint()
        return

    if args.wandb_project and misc.get_rank() == 0:
        wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=args)
        log_wandb = args.wandb_project
    else:
        log_wandb = None

    for epoch in range(start_epoch, args.n_epochs):
        # DistributedSampler derives its shuffle seed from the epoch. Without
        # set_epoch, every epoch replays the same sample order.
        sampler_train.set_epoch(epoch)

        start_time = time.time()
        log_dict = train_epoch(
            model, train_loader, optimizer, device=args.device, max_norm=args.max_norm, epoch=epoch,
            log_wandb=log_wandb, batch_size=eff_batch_size, args=args, dino_model=dino_model, augmentor=augmentor,
        )
        end_time = time.time()
        print(f"Epoch {epoch} train loss: {log_dict['loss']:.4f} grad_norm: {log_dict['grad_norm']:.4f} \n")
        print(f"Time taken for epoch {epoch}: {end_time - start_time:.2f} seconds")

        if epoch % args.eval_every == 0:
            test_log_dict = eval_epoch(
                model, test_loader, device=args.device, dino_model=dino_model, args=args, augmentor=augmentor,
            )
            print(f"Epoch {epoch} test loss: {test_log_dict['loss']:.4f} \n")
            if log_wandb:
                wandb_dict = {f"test_{k}": v for k, v in test_log_dict.items()}
                wandb.log(wandb_dict, step=(epoch+1)*n_scenes_epoch)

        if epoch % args.save_every == 0:
            misc.save_model(args, epoch, model_without_ddp, optimizer, epoch_name=f"latest")
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line, then hand the options to main().
    args = parse_args()
    main(args=args)