File size: 5,859 Bytes
70d1188
 
 
db4b25c
 
 
 
 
 
 
 
 
 
 
70d1188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
bb=breakpoint
import torch
from utils.batch_prep import prepare_fast_batch, normalize_batch, denormalize_batch
from utils.utils import scenes_to_batch, batch_to_scenes
from utils.geometry import center_pointmaps, uncenter_pointmaps


# from utils.viz import save_pointmaps
# from tqdm import tqdm
# import wandb
# from utils import misc
# from torch.amp import GradScaler
# from utils.eval import eval_pred
# from utils.geometry import depth2pts

def batch_to_device(batch,device='cuda'):
    for key in batch:
        if isinstance(batch[key],torch.Tensor):
            batch[key] = batch[key].to(device)
        elif isinstance(batch[key],dict):
            batch[key] = batch_to_device(batch[key],device)
    return batch

def eval_model(model,batch,mode='loss',device='cuda',dino_model=None,args=None,augmentor=None,return_scale=False):
    batch = batch_to_device(batch,device)
    # check if model is distributed
    if isinstance(model,torch.nn.parallel.DistributedDataParallel):
        dino_layers = model.module.dino_layers
    else:
        dino_layers = model.dino_layers
    if 'pointmaps' not in list(batch['input_cams'].keys()):
        batch = prepare_fast_batch(batch,dino_model,dino_layers)

    normalize_mode = args.normalize_mode if args is not None else 'median'
    batch, scale_factors = normalize_batch(batch,normalize_mode)
    if augmentor is not None:
        batch = augmentor(batch)

    batch, n_cams = scenes_to_batch(batch)
    batch = center_pointmaps(batch) # centering around first camera

    device = args.device if args is not None else 'cuda'
    with torch.amp.autocast(device_type=device, dtype=torch.bfloat16):
        pred, gt, loss_dict = model(batch,mode='viz')
    
    if 'pointmaps' not in list(pred.keys()):
        pred['pointmaps'] = depth2pts(pred['depths'].squeeze(-1),batch['new_cams']['Ks'])
    elif 'depths' not in list(pred.keys()):
        pred['depths'] = pred['pointmaps'][...,-1]
    loss_dict = {**loss_dict,**eval_pred(pred, gt)}
    if mode == 'loss':
        return loss_dict
    elif mode == 'viz':
        pred, gt, batch = uncenter_pointmaps(pred, gt, batch)
        pred, gt, batch = batch_to_scenes(pred, gt,batch, n_cams)
        if return_scale:
            return pred, gt, loss_dict, scale_factors[0].item()
        else:
            return pred, gt, loss_dict
    else:
        raise ValueError(f"Invalid mode: {mode}")

def update_loss_dict(loss_dict,loss_dict_new,sample_count):
    for key in loss_dict_new:
        if key not in loss_dict:
            loss_dict[key] = loss_dict_new[key]
        else:
            # previously stored value in loss_dict is average from sample_count samples
            # so we need to update it to include the new sample
            loss_dict[key] = (loss_dict[key] * sample_count + loss_dict_new[key]) / (sample_count + 1)
    return loss_dict

def train_epoch(model, train_loader, optimizer, device='cuda', max_norm=1.0,log_wandb=False,epoch=0,batch_size=None,args=None,dino_model=None,augmentor=None):
    model.train()
    all_losses_dict = {}
    
    sample_idx = epoch * batch_size * len(train_loader)
    scaler = GradScaler()
    for i, batch in tqdm(enumerate(train_loader),total=len(train_loader)):
        optimizer.zero_grad()
        new_loss_dict = eval_model(model, batch, mode='loss', device=device,dino_model=dino_model,args=args,augmentor=augmentor)
        loss = new_loss_dict['loss']
        if loss is None:
            continue
        
        scaler.scale(loss).backward()
        # Unscales the gradients of optimizer's assigned params in-place
        scaler.unscale_(optimizer)
        
        grad_norm = torch.norm(torch.stack([torch.norm(p.grad) for p in model.parameters() if p.grad is not None]))
        if grad_norm.isnan():
            breakpoint()
        
        ## Since the gradients of optimizer's assigned params are unscaled, clips as usual:
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        
        # optimizer's gradients are already unscaled, so scaler.step does not unscale them,
        # although it still skips optimizer.step() if the gradients contain infs or NaNs.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()

        new_loss_dict['grad_norm'] = grad_norm.detach().cpu().item()

        misc.adjust_learning_rate(optimizer, epoch + i/len(train_loader), args)
        optimizer.step()
        
        new_loss_dict = {k: (v.detach().cpu().item() if isinstance(v, torch.Tensor) else v) for k, v in new_loss_dict.items()}
        if log_wandb:
            wandb_dict = {f"train_{k}":v for k,v in new_loss_dict.items()}
            wandb.log(wandb_dict, step=sample_idx + (i+1)*batch_size)
            # log learning rate
            wandb.log({"train_lr": optimizer.param_groups[0]['lr']}, step=sample_idx + (i+1)*batch_size)
        
        all_losses_dict = update_loss_dict(all_losses_dict, new_loss_dict,sample_count=i)
        # Clear cache and delete variables to free memory
        torch.cuda.empty_cache()
        del loss
        del new_loss_dict
        del grad_norm
        del batch
    
    return all_losses_dict

def eval_epoch(model,test_loader,device='cuda',dino_model=None,args=None,augmentor=None):
    model.eval()
    all_losses_dict = {}
    with torch.no_grad():
        for i, batch in tqdm(enumerate(test_loader),total=len(test_loader)):
            new_loss_dict = eval_model(model,batch,mode='loss',device=device,dino_model=dino_model,args=args,augmentor=augmentor)
            if new_loss_dict is None:
                continue
            all_losses_dict = update_loss_dict(all_losses_dict,new_loss_dict,sample_count=i)

            torch.cuda.empty_cache()
            del new_loss_dict
            del batch
    
    return all_losses_dict