bb = breakpoint  # short alias for dropping into the debugger

import torch
from torch.amp import GradScaler
from tqdm import tqdm
import wandb

from utils import misc
from utils.batch_prep import prepare_fast_batch, normalize_batch, denormalize_batch
from utils.utils import scenes_to_batch, batch_to_scenes
from utils.geometry import center_pointmaps, uncenter_pointmaps, depth2pts
from utils.eval import eval_pred
# from utils.viz import save_pointmaps

def batch_to_device(batch, device='cuda'):
    """Recursively move every tensor in a (possibly nested) batch dict to `device`."""
    for key in batch:
        if isinstance(batch[key], torch.Tensor):
            batch[key] = batch[key].to(device)
        elif isinstance(batch[key], dict):
            batch[key] = batch_to_device(batch[key], device)
    return batch
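
# Illustrative example (not part of the training flow): nested dicts are handled
# recursively, so camera sub-dicts move along with the top-level tensors.
#   batch = {'imgs': torch.zeros(2, 3), 'input_cams': {'Ks': torch.eye(3)}}
#   batch = batch_to_device(batch, 'cpu')  # every tensor now on the target device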

def eval_model(model, batch, mode='loss', device='cuda', dino_model=None, args=None,
               augmentor=None, return_scale=False):
    """Run one forward pass and return losses (mode='loss') or predictions (mode='viz')."""
    batch = batch_to_device(batch, device)
    # DistributedDataParallel wraps the model, so reach through .module if needed
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        dino_layers = model.module.dino_layers
    else:
        dino_layers = model.dino_layers
    if 'pointmaps' not in batch['input_cams']:
        batch = prepare_fast_batch(batch, dino_model, dino_layers)
    normalize_mode = args.normalize_mode if args is not None else 'median'
    batch, scale_factors = normalize_batch(batch, normalize_mode)
    if augmentor is not None:
        batch = augmentor(batch)
    batch, n_cams = scenes_to_batch(batch)
    batch = center_pointmaps(batch)  # centering around the first camera
    device = args.device if args is not None else 'cuda'
    with torch.amp.autocast(device_type=device, dtype=torch.bfloat16):
        pred, gt, loss_dict = model(batch, mode='viz')
    # Fill in whichever of pointmaps/depths the model did not predict directly
    if 'pointmaps' not in pred:
        pred['pointmaps'] = depth2pts(pred['depths'].squeeze(-1), batch['new_cams']['Ks'])
    elif 'depths' not in pred:
        pred['depths'] = pred['pointmaps'][..., -1]
    loss_dict = {**loss_dict, **eval_pred(pred, gt)}
    if mode == 'loss':
        return loss_dict
    elif mode == 'viz':
        pred, gt, batch = uncenter_pointmaps(pred, gt, batch)
        pred, gt, batch = batch_to_scenes(pred, gt, batch, n_cams)
        if return_scale:
            return pred, gt, loss_dict, scale_factors[0].item()
        return pred, gt, loss_dict
    else:
        raise ValueError(f"Invalid mode: {mode}")
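
# Usage sketch (assumes `model` and `batch` come from this repo's loaders; the
# exact batch contents are an assumption here):
#   losses = eval_model(model, batch, mode='loss')           # dict of scalar losses/metrics
#   pred, gt, losses = eval_model(model, batch, mode='viz')  # per-scene outputs for visualization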

def update_loss_dict(loss_dict, loss_dict_new, sample_count):
    """Fold a new per-batch loss dict into a running mean over `sample_count` previous samples."""
    for key in loss_dict_new:
        if key not in loss_dict:
            loss_dict[key] = loss_dict_new[key]
        else:
            # the stored value is an average over `sample_count` samples,
            # so extend that average to include the new sample
            loss_dict[key] = (loss_dict[key] * sample_count + loss_dict_new[key]) / (sample_count + 1)
    return loss_dict
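
# Worked example of the running mean:
#   d = update_loss_dict({}, {'loss': 0.4}, sample_count=0)  # d == {'loss': 0.4}
#   d = update_loss_dict(d, {'loss': 0.8}, sample_count=1)   # d == {'loss': 0.6}, i.e. (0.4*1 + 0.8)/2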

def train_epoch(model, train_loader, optimizer, device='cuda', max_norm=1.0, log_wandb=False,
                epoch=0, batch_size=None, args=None, dino_model=None, augmentor=None):
    model.train()
    all_losses_dict = {}
    sample_idx = epoch * batch_size * len(train_loader)
    scaler = GradScaler()
    for i, batch in tqdm(enumerate(train_loader), total=len(train_loader)):
        optimizer.zero_grad()
        new_loss_dict = eval_model(model, batch, mode='loss', device=device,
                                   dino_model=dino_model, args=args, augmentor=augmentor)
        loss = new_loss_dict['loss']
        if loss is None:
            continue
        scaler.scale(loss).backward()
        # Unscale the gradients of the optimizer's assigned params in-place
        scaler.unscale_(optimizer)
        grad_norm = torch.norm(torch.stack([torch.norm(p.grad)
                                            for p in model.parameters() if p.grad is not None]))
        if grad_norm.isnan():
            breakpoint()  # debug hook: stop on NaN gradients
        # Since the gradients are already unscaled, clip as usual
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        # The gradients are already unscaled, so scaler.step does not unscale them again,
        # but it still skips the parameter update if the gradients contain infs or NaNs.
        # Note: scaler.step performs the optimizer step itself, so no separate
        # optimizer.step() is needed (a second call would apply the update twice).
        scaler.step(optimizer)
        # Update the scale for the next iteration
        scaler.update()
        new_loss_dict['grad_norm'] = grad_norm.detach().cpu().item()
        misc.adjust_learning_rate(optimizer, epoch + i / len(train_loader), args)
        new_loss_dict = {k: (v.detach().cpu().item() if isinstance(v, torch.Tensor) else v)
                         for k, v in new_loss_dict.items()}
        if log_wandb:
            wandb_dict = {f"train_{k}": v for k, v in new_loss_dict.items()}
            wandb.log(wandb_dict, step=sample_idx + (i + 1) * batch_size)
            # log the learning rate as well
            wandb.log({"train_lr": optimizer.param_groups[0]['lr']},
                      step=sample_idx + (i + 1) * batch_size)
        all_losses_dict = update_loss_dict(all_losses_dict, new_loss_dict, sample_count=i)
        # Clear cache and delete variables to free memory
        torch.cuda.empty_cache()
        del loss
        del new_loss_dict
        del grad_norm
        del batch
    return all_losses_dict

def eval_epoch(model, test_loader, device='cuda', dino_model=None, args=None, augmentor=None):
    model.eval()
    all_losses_dict = {}
    with torch.no_grad():
        for i, batch in tqdm(enumerate(test_loader), total=len(test_loader)):
            new_loss_dict = eval_model(model, batch, mode='loss', device=device,
                                       dino_model=dino_model, args=args, augmentor=augmentor)
            if new_loss_dict is None:
                continue
            all_losses_dict = update_loss_dict(all_losses_dict, new_loss_dict, sample_count=i)
            torch.cuda.empty_cache()
            del new_loss_dict
            del batch
    return all_losses_dict
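
# Minimal wiring sketch. `make_model`, `make_loaders`, and the `args` fields used
# below are hypothetical placeholders, not part of this module:
#   model = make_model(args).to(args.device)
#   train_loader, test_loader = make_loaders(args)
#   optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
#   for epoch in range(args.epochs):
#       train_losses = train_epoch(model, train_loader, optimizer, device=args.device,
#                                  epoch=epoch, batch_size=args.batch_size, args=args)
#       test_losses = eval_epoch(model, test_loader, device=args.device, args=args)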