rayst3r / engine.py
bartduis's picture
init
70d1188
raw
history blame
5.84 kB
# Short alias for the builtin debugger hook (debugging convenience; no other
# use of `bb` is visible in this file).
bb = breakpoint
import torch
from utils.geometry import center_pointmaps, uncenter_pointmaps
from utils.utils import scenes_to_batch, batch_to_scenes
from utils.batch_prep import prepare_fast_batch, normalize_batch, denormalize_batch
from utils.viz import save_pointmaps
from tqdm import tqdm
import wandb
from utils import misc
from torch.amp import GradScaler
from utils.eval import eval_pred
from utils.geometry import depth2pts
def batch_to_device(batch, device='cuda'):
    """Move every tensor in a (possibly nested) batch dict onto ``device``.

    The dict is updated in place: each ``torch.Tensor`` value is replaced by
    ``value.to(device)`` and each nested dict is processed recursively. Values
    of any other type are left as they are.

    Args:
        batch: dict whose values may be tensors, nested dicts, or anything else.
        device: target handed to ``Tensor.to``.

    Returns:
        The same ``batch`` object (mutated), for call-chaining convenience.
    """
    for name, value in batch.items():
        # Re-assigning existing keys does not change the dict's size, so it is
        # safe to do while iterating.
        if isinstance(value, dict):
            batch[name] = batch_to_device(value, device)
        elif isinstance(value, torch.Tensor):
            batch[name] = value.to(device)
    return batch
def eval_model(model, batch, mode='loss', device='cuda', dino_model=None, args=None, augmentor=None, return_scale=False):
    """Run ``model`` on one batch and return its losses (plus predictions in viz mode).

    Pipeline:
        1. Move the batch to ``device``.
        2. If ``batch['input_cams']`` has no ``'pointmaps'`` entry, build them
           with ``prepare_fast_batch``.
        3. Normalize, optionally augment, flatten scenes into one batch and
           center the pointmaps.
        4. Run the model under bfloat16 autocast.
        5. Fill in whichever of ``pred['pointmaps']`` / ``pred['depths']`` the
           model did not produce.
        6. Merge ``eval_pred`` metrics into the loss dict.

    Args:
        model: network, optionally wrapped in ``DistributedDataParallel``,
            exposing ``dino_layers`` and called as
            ``model(batch, mode='viz') -> (pred, gt, loss_dict)``.
        batch: nested dict of tensors; must contain ``'input_cams'``.
        mode: ``'loss'`` returns only the loss dict; ``'viz'`` also returns
            un-centered predictions and ground truth regrouped into scenes.
        device: where the batch is moved. It is also the autocast device when
            ``args`` is None.
        dino_model: feature extractor handed to ``prepare_fast_batch``.
        args: optional config. ``args.normalize_mode`` (default ``'median'``)
            and ``args.device`` are read from it.
        augmentor: optional callable applied to the normalized batch.
        return_scale: in ``'viz'`` mode, also return ``scale_factors[0]`` as a
            Python number.

    Returns:
        ``loss_dict`` in ``'loss'`` mode.
        ``(pred, gt, loss_dict)`` or ``(pred, gt, loss_dict, scale)`` in
        ``'viz'`` mode.

    Raises:
        ValueError: if ``mode`` is neither ``'loss'`` nor ``'viz'``.
    """
    # Validate up front so a typo does not cost a full forward pass.
    if mode not in ('loss', 'viz'):
        raise ValueError(f"Invalid mode: {mode}")

    batch = batch_to_device(batch, device)

    # check if model is distributed
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        dino_layers = model.module.dino_layers
    else:
        dino_layers = model.dino_layers

    if 'pointmaps' not in batch['input_cams']:
        batch = prepare_fast_batch(batch, dino_model, dino_layers)
    normalize_mode = args.normalize_mode if args is not None else 'median'
    batch, scale_factors = normalize_batch(batch, normalize_mode)
    if augmentor is not None:
        batch = augmentor(batch)

    batch, n_cams = scenes_to_batch(batch)
    batch = center_pointmaps(batch)  # centering around first camera

    # autocast expects a device *type* ('cuda', 'cpu', ...), not an indexed
    # device such as 'cuda:0'. Without args, honour the `device` argument
    # instead of a hard-coded 'cuda'.
    amp_device = args.device if args is not None else device
    device_type = torch.device(amp_device).type
    with torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):
        # Always called in 'viz' mode; it returns (pred, gt, loss_dict).
        pred, gt, loss_dict = model(batch, mode='viz')

    # Derive whichever geometric output the model did not produce itself.
    if 'pointmaps' not in pred:
        pred['pointmaps'] = depth2pts(pred['depths'].squeeze(-1), batch['new_cams']['Ks'])
    elif 'depths' not in pred:
        # Depth is taken as the last coordinate of each pointmap entry.
        pred['depths'] = pred['pointmaps'][..., -1]
    loss_dict = {**loss_dict, **eval_pred(pred, gt)}

    if mode == 'loss':
        return loss_dict

    # mode == 'viz': undo the centering and regroup the flat batch into scenes.
    pred, gt, batch = uncenter_pointmaps(pred, gt, batch)
    pred, gt, batch = batch_to_scenes(pred, gt, batch, n_cams)
    if return_scale:
        return pred, gt, loss_dict, scale_factors[0].item()
    return pred, gt, loss_dict
def update_loss_dict(loss_dict, loss_dict_new, sample_count):
    """Fold one sample's metrics into a running per-key mean, in place.

    Args:
        loss_dict: accumulator. Each stored value is assumed to be the mean
            over ``sample_count`` previous samples.
        loss_dict_new: metrics of the new sample. Keys not yet in
            ``loss_dict`` are inserted as-is.
        sample_count: number of samples already averaged into ``loss_dict``.

    Returns:
        ``loss_dict`` itself, updated.
    """
    for name, new_value in loss_dict_new.items():
        if name in loss_dict:
            # Re-weight the stored mean by its sample count, add the new
            # sample, and divide by the new count.
            total = loss_dict[name] * sample_count + new_value
            loss_dict[name] = total / (sample_count + 1)
        else:
            loss_dict[name] = new_value
    return loss_dict
def train_epoch(model, train_loader, optimizer, device='cuda', max_norm=1.0, log_wandb=False, epoch=0, batch_size=None, args=None, dino_model=None, augmentor=None):
    """Train ``model`` for one epoch and return the epoch-averaged metrics.

    Each iteration:
        1. Set the learning rate via ``misc.adjust_learning_rate``.
        2. Compute losses with ``eval_model``.
        3. Back-propagate through a ``GradScaler`` and unscale the gradients.
        4. Record the (pre-clipping) global gradient norm.
        5. Optionally clip gradients and take one scaler-managed optimizer step.
        6. Optionally log to wandb.

    Args:
        model: network to train (put in train mode here).
        train_loader: iterable of batches supporting ``len()``.
        optimizer: optimizer stepped through the ``GradScaler``.
        device: forwarded to ``eval_model``.
        max_norm: gradient-clipping threshold; ``<= 0`` disables clipping.
        log_wandb: if True, log per-step metrics and the learning rate.
        epoch: current epoch index (used for the LR schedule and wandb step).
        batch_size: used only to compute the wandb step. When None, falls back
            to ``train_loader.batch_size`` (or 1 if that is unavailable).
        args: config forwarded to ``eval_model`` and ``misc.adjust_learning_rate``.
        dino_model: forwarded to ``eval_model``.
        augmentor: forwarded to ``eval_model``.

    Returns:
        dict mapping metric name to its mean over the batches that were not
        skipped. Tensor values are converted to Python numbers.
    """
    import warnings

    model.train()
    all_losses_dict = {}
    n_batches = len(train_loader)
    if batch_size is None:
        # The original computed `epoch * None * ...` here and crashed.
        batch_size = getattr(train_loader, 'batch_size', None) or 1
    sample_idx = epoch * batch_size * n_batches
    scaler = GradScaler()
    # Batches actually folded into all_losses_dict. Differs from `i` once a
    # batch is skipped, so it is what the running mean must be weighted by.
    n_logged = 0

    for i, batch in tqdm(enumerate(train_loader), total=n_batches):
        # Set this iteration's learning rate before the optimizer step uses it.
        misc.adjust_learning_rate(optimizer, epoch + i / n_batches, args)
        optimizer.zero_grad()
        new_loss_dict = eval_model(model, batch, mode='loss', device=device, dino_model=dino_model, args=args, augmentor=augmentor)
        loss = new_loss_dict['loss']
        if loss is None:
            continue

        scaler.scale(loss).backward()
        # Unscale the gradients of the optimizer's params in place, so the
        # norm below and clipping operate on true gradient magnitudes.
        scaler.unscale_(optimizer)

        grads = [p.grad for p in model.parameters() if p.grad is not None]
        if grads:
            grad_norm = torch.norm(torch.stack([torch.norm(g) for g in grads]))
        else:
            # torch.stack([]) would raise; report a zero norm instead.
            grad_norm = torch.zeros(())
        if not torch.isfinite(grad_norm):
            # scaler.step() below skips the update for non-finite gradients,
            # so warn instead of dropping into an interactive debugger.
            warnings.warn(f"Non-finite gradient norm at epoch {epoch}, iteration {i}; optimizer step will be skipped.")

        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        # scaler.step() calls optimizer.step() itself (skipping it when the
        # gradients contain inf/NaN). Calling optimizer.step() again would
        # apply the update twice and bypass that check.
        scaler.step(optimizer)
        # Update the scale for the next iteration.
        scaler.update()

        new_loss_dict['grad_norm'] = grad_norm.detach().cpu().item()
        new_loss_dict = {k: (v.detach().cpu().item() if isinstance(v, torch.Tensor) else v) for k, v in new_loss_dict.items()}
        if log_wandb:
            wandb_dict = {f"train_{k}": v for k, v in new_loss_dict.items()}
            # log learning rate
            wandb_dict["train_lr"] = optimizer.param_groups[0]['lr']
            wandb.log(wandb_dict, step=sample_idx + (i + 1) * batch_size)

        all_losses_dict = update_loss_dict(all_losses_dict, new_loss_dict, sample_count=n_logged)
        n_logged += 1

        # Drop references first so empty_cache can actually release the memory.
        del loss, new_loss_dict, grad_norm, grads, batch
        torch.cuda.empty_cache()
    return all_losses_dict
def eval_epoch(model, test_loader, device='cuda', dino_model=None, args=None, augmentor=None):
    """Evaluate ``model`` over ``test_loader`` without gradients and return mean metrics.

    Args:
        model: network to evaluate (put in eval mode here and left that way).
        test_loader: iterable of batches supporting ``len()``.
        device: forwarded to ``eval_model``.
        dino_model: forwarded to ``eval_model``.
        args: forwarded to ``eval_model``.
        augmentor: forwarded to ``eval_model``.

    Returns:
        dict mapping metric name to its mean over the evaluated batches.
        Values are not converted to Python numbers here.
    """
    model.eval()
    all_losses_dict = {}
    # Count only batches actually averaged in, so a skipped batch does not
    # skew the running mean (the original weighted by the loop index).
    n_logged = 0
    with torch.no_grad():
        for batch in tqdm(test_loader, total=len(test_loader)):
            new_loss_dict = eval_model(model, batch, mode='loss', device=device, dino_model=dino_model, args=args, augmentor=augmentor)
            if new_loss_dict is None:
                continue
            all_losses_dict = update_loss_dict(all_losses_dict, new_loss_dict, sample_count=n_logged)
            n_logged += 1
            # Drop references first so empty_cache can actually release the memory.
            del new_loss_dict, batch
            torch.cuda.empty_cache()
    return all_losses_dict