# rayst3r / main.py
# (scraped file-page header) author: bartduis, commit 70d1188 "init", 9.73 kB
# Short alias for the builtin breakpoint() (debugging convenience); not used by the code shown in this file.
bb = breakpoint
import torch
from torch.utils.data import DataLoader
import wandb
from argparse import ArgumentParser
from datasets.octmae import OctMae
from datasets.foundation_pose import FoundationPose
from datasets.generic_loader import GenericLoader
from utils.collate import collate
from models.rayquery import RayQuery
from engine import train_epoch, eval_epoch, eval_model
import torch.nn as nn
from models.rayquery import RayQuery, PointmapEncoder, RayEncoder
from models.losses import *
import utils.misc as misc
import os
from utils.viz import just_load_viz
from utils.fusion import fuse_batch
import socket
import time
from utils.augmentations import *
def parse_args(argv=None):
    """Build the command-line parser for training / visualization and parse it.

    The dataset, model and augmentor flags hold Python expressions (e.g.
    ``"TableOfCubes(size=10,n_views=2,seed=747)"``) that ``main`` later passes
    to ``eval``; here they are only parsed as plain strings.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the original
            behavior of reading the real command line; passing a list makes
            the parser usable from tests and programmatic launchers.

    Returns:
        argparse.Namespace with one attribute per flag defined below.
    """
    parser = ArgumentParser()
    # Data / model specification (evaluated as Python expressions by main()).
    parser.add_argument("--dataset_train", type=str, default="TableOfCubes(size=10,n_views=2,seed=747)")
    parser.add_argument("--dataset_test", type=str, default="TableOfCubes(size=10,n_views=2,seed=787)")
    parser.add_argument("--dataset_just_load", type=str, default=None)
    parser.add_argument("--logdir", type=str, default="logs")
    parser.add_argument("--batch_size", type=int, default=5)
    parser.add_argument("--n_epochs", type=int, default=100)
    parser.add_argument("--n_workers", type=int, default=4)
    parser.add_argument("--model", type=str, default="RayQuery(ray_enc=RayEncoder(),pointmap_enc=PointmapEncoder(),criterion=RayCompletion(ConfLoss(L21)))")
    parser.add_argument("--save_every", type=int, default=1)
    parser.add_argument("--resume", type=str, default=None)
    parser.add_argument("--eval_every", type=int, default=3)
    parser.add_argument("--wandb_project", type=str, default=None)
    parser.add_argument("--wandb_run_name", type=str, default="init")
    parser.add_argument("--just_load", action="store_true")
    parser.add_argument("--device", type=str, default="cuda")
    # Port comes from the RERUN_RECORDING environment variable when set.
    parser.add_argument("--rr_addr", type=str, default="0.0.0.0:"+os.getenv("RERUN_RECORDING","9876"))
    parser.add_argument("--mesh", action="store_true")
    parser.add_argument("--max_norm", type=float, default=-1)
    # Optimization hyper-parameters.
    parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1.5e-4, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--warmup_epochs', type=int, default=10)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    # The string 'None' (not the None object) is the "disabled" sentinel here.
    parser.add_argument('--normalize_mode',type=str,default='None')
    parser.add_argument('--start_from',type=str,default=None)
    parser.add_argument('--augmentor',type=str,default='None')
    return parser.parse_args(argv)
def _build_loader(dataset, sampler, args, num_workers=0):
    """Wrap *dataset* in a DataLoader driven by *sampler*, batched with the project ``collate``."""
    loader_kwargs = dict(
        sampler=sampler,
        batch_size=args.batch_size,
        shuffle=False,  # ordering is delegated to the (distributed) sampler
        collate_fn=collate,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )
    # prefetch_factor only applies to worker processes; PyTorch 2.x raises
    # ValueError when it is given together with num_workers=0.
    if num_workers > 0:
        loader_kwargs["prefetch_factor"] = 2
    return DataLoader(dataset, **loader_kwargs)


def _run_viz_loop(model, just_loader, args, dino_model, augmentor):
    """Repeatedly run the model over *just_loader*, print per-batch losses and visualize results."""
    with torch.no_grad():
        while True:
            for data in just_loader:
                pred, gt, loss_dict, batch = eval_model(model, data, mode='viz', args=args,
                                                        dino_model=dino_model, augmentor=augmentor)
                # cast to float32 for visualization
                gt = {k: v.float() for k, v in gt.items()}
                pred = {k: v.float() for k, v in pred.items()}
                # summarize all keys in loss_dict in a small table
                print(f"{'Key':<10} {'Value':<10}")
                print("-"*20)
                for key, value in loss_dict.items():
                    print(f"{key:<10}: {value:.4f}")
                print("-"*20)
                fused_meshes = fuse_batch(pred, gt, data, voxel_size=0.002) if args.mesh else None
                just_load_viz(pred, gt, batch, addr=args.rr_addr, name=args.logdir,
                              fused_meshes=fused_meshes)
                # Pauses in the debugger after every batch (presumably so the
                # visualization can be inspected); the loop has no exit other
                # than an exception, e.g. quitting the debugger session.
                breakpoint()


def main(args):
    """Train the model, or visualize its predictions when ``--just_load`` is set.

    Training mode builds distributed train/test loaders from
    ``args.dataset_train`` / ``args.dataset_test``, wraps the model in
    DistributedDataParallel, optionally resumes from
    ``<args.resume>/checkpoint-latest.pth`` (or warm-starts from
    ``args.start_from``), then alternates ``train_epoch`` with periodic
    ``eval_epoch`` calls and checkpoint saving.

    Visualization mode (``args.just_load``) loads ``args.dataset_just_load``
    (falling back to ``args.dataset_train``) and loops over it indefinitely,
    printing losses and sending predictions to ``just_load_viz``.

    Dataset, model and augmentor flags are Python expressions executed with
    ``eval``; only run this with trusted command lines.

    Args:
        args: Namespace from ``parse_args``. ``args.lr`` is filled in from
            ``args.blr`` (scaled by the effective batch size) when it is None;
            this mutates *args* in place.
    """
    load_dino = False
    if not args.just_load:
        # NOTE(security): eval() runs arbitrary code taken from the CLI.
        dataset_train = eval(args.dataset_train)
        dataset_test = eval(args.dataset_test)
        # DINO features must be computed online unless the dataset prefetches them.
        if not dataset_train.prefetch_dino:
            load_dino = True
        rank, world_size, local_rank = misc.setup_distributed()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=world_size, rank=rank, shuffle=True
        )
        sampler_test = torch.utils.data.DistributedSampler(
            dataset_test, num_replicas=world_size, rank=rank, shuffle=False
        )
        train_loader = _build_loader(dataset_train, sampler_train, args, num_workers=args.n_workers)
        test_loader = _build_loader(dataset_test, sampler_test, args, num_workers=args.n_workers)
        # Scenes seen per epoch across all ranks; used as the wandb step unit.
        n_scenes_epoch = len(train_loader) * args.batch_size * world_size
        print(f"Number of scenes in epoch: {n_scenes_epoch}")
    else:
        if args.dataset_just_load is None:
            dataset = eval(args.dataset_train)
        else:
            dataset = eval(args.dataset_just_load)
        if not dataset.prefetch_dino:
            load_dino = True
        rank, world_size, local_rank = misc.setup_distributed()
        sampler_just_load = torch.utils.data.DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=False
        )
        # Original behavior: no worker processes for the visualization loader.
        just_loader = _build_loader(dataset, sampler_just_load, args)

    model = eval(args.model).to(args.device)
    augmentor = eval(args.augmentor) if args.augmentor != 'None' else None

    if load_dino and len(model.dino_layers) > 0:
        dino_model = torch.hub.load('facebookresearch/dinov2', "dinov2_vitl14_reg")
        dino_model.eval()
        # Same device as the main model (was hard-coded to "cuda").
        dino_model.to(args.device)
    else:
        dino_model = None

    # distribute model
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=True)
    model_without_ddp = model.module if hasattr(model, 'module') else model

    eff_batch_size = args.batch_size * misc.get_world_size()
    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 256
    param_groups = misc.add_weight_decay(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    os.makedirs(args.logdir, exist_ok=True)

    start_epoch = 0
    print("Running on host %s" % socket.gethostname())
    if args.resume and os.path.exists(os.path.join(args.resume, "checkpoint-latest.pth")):
        checkpoint = torch.load(os.path.join(args.resume, "checkpoint-latest.pth"), map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        print("Resume checkpoint %s" % args.resume)
        if 'optimizer' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch'] + 1
            print("With optim & sched!")
        del checkpoint
    elif args.start_from is not None:
        # Weights only: optimizer state and epoch counter start fresh.
        checkpoint = torch.load(args.start_from, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        print("Start from checkpoint %s" % args.start_from)
        del checkpoint

    if args.just_load:
        _run_viz_loop(model, just_loader, args, dino_model, augmentor)
        return

    if args.wandb_project and misc.get_rank() == 0:
        wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=args)
        log_wandb = args.wandb_project
    else:
        log_wandb = None

    for epoch in range(start_epoch, args.n_epochs):
        # DistributedSampler seeds its shuffle with the epoch number; without
        # set_epoch() every epoch would replay the same sample order.
        sampler_train.set_epoch(epoch)
        start_time = time.time()
        log_dict = train_epoch(model, train_loader, optimizer, device=args.device, max_norm=args.max_norm, epoch=epoch,
                               log_wandb=log_wandb, batch_size=eff_batch_size, args=args, dino_model=dino_model,
                               augmentor=augmentor)
        end_time = time.time()
        print(f"Epoch {epoch} train loss: {log_dict['loss']:.4f} grad_norm: {log_dict['grad_norm']:.4f} \n")
        print(f"Time taken for epoch {epoch}: {end_time - start_time:.2f} seconds")
        if epoch % args.eval_every == 0:
            test_log_dict = eval_epoch(model, test_loader, device=args.device, dino_model=dino_model, args=args,
                                       augmentor=augmentor)
            print(f"Epoch {epoch} test loss: {test_log_dict['loss']:.4f} \n")
            if log_wandb:
                wandb_dict = {f"test_{k}": v for k, v in test_log_dict.items()}
                wandb.log(wandb_dict, step=(epoch+1)*n_scenes_epoch)
        if epoch % args.save_every == 0:
            # Saved under the "latest" name, presumably checkpoint-latest.pth,
            # which is what --resume reads. Per-epoch checkpoints were dropped
            # because they took too much disk space.
            misc.save_model(args, epoch, model_without_ddp, optimizer, epoch_name="latest")
if __name__ == "__main__":
    # Script entry point: parse CLI flags, then hand them to main() for
    # training or visualization.
    args = parse_args()
    main(args=args)