| """ | |
| "LiftFeat: 3D Geometry-Aware Local Feature Matching" | |
| training script | |
| """ | |
| import argparse | |
| import os | |
| import time | |
| import sys | |
| sys.path.append(os.path.dirname(__file__)) | |


def parse_arguments():
    parser = argparse.ArgumentParser(description="LiftFeat training script.")
    parser.add_argument('--name', type=str, default='LiftFeat', help='Set the process name.')

    # MegaDepth dataset settings
    parser.add_argument('--use_megadepth', action='store_true')
    parser.add_argument('--megadepth_root_path', type=str,
                        default='/home/yepeng_liu/code_python/dataset/MegaDepth/phoenix/S6/zl548',
                        help='Path to the MegaDepth dataset root directory.')
    parser.add_argument('--megadepth_batch_size', type=int, default=6)

    # COCO20k dataset settings
    parser.add_argument('--use_coco', action='store_true')
    parser.add_argument('--coco_root_path', type=str,
                        default='/home/yepeng_liu/code_python/dataset/coco_20k',
                        help='Path to the COCO20k dataset root directory.')
    parser.add_argument('--coco_batch_size', type=int, default=4)

    parser.add_argument('--ckpt_save_path', type=str,
                        default='/home/yepeng_liu/code_python/LiftFeat/trained_weights/test',
                        help='Path to save the checkpoints.')
    parser.add_argument('--n_steps', type=int, default=160_000,
                        help='Number of training steps. Default is 160000.')
    parser.add_argument('--lr', type=float, default=3e-4,
                        help='Learning rate. Default is 0.0003.')
    parser.add_argument('--gamma_steplr', type=float, default=0.5,
                        help='Gamma value for the StepLR scheduler. Default is 0.5.')
    parser.add_argument('--training_res', type=lambda s: tuple(map(int, s.split(','))),
                        default=(800, 608),
                        help='Training resolution as width,height. Default is (800, 608).')
    parser.add_argument('--device_num', type=str, default='0',
                        help='Device number to use for training. Default is "0".')
    parser.add_argument('--dry_run', action='store_true',
                        help='If set, perform a dry run with a mini-batch as a sanity check.')
    parser.add_argument('--save_ckpt_every', type=int, default=500,
                        help='Save a checkpoint every N steps. Default is 500.')

    args = parser.parse_args()

    # Restrict the visible devices before torch is imported below, so the setting takes effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device_num

    return args
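

# Example invocation (assuming this file is saved as train.py; the dataset paths
# are the repository defaults above -- point them at your local copies):
#   python train.py --use_megadepth --use_coco \
#       --megadepth_root_path /path/to/MegaDepth \
#       --coco_root_path /path/to/coco_20k \
#       --ckpt_save_path ./trained_weights/test --device_num 0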

args = parse_arguments()

# torch is imported only after parse_arguments() has set CUDA_VISIBLE_DEVICES,
# so the GPU selection above is honored.
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader

import numpy as np
import tqdm
import glob

from models.model import LiftFeatSPModel
from loss.loss import LiftFeatLoss
from utils.config import featureboost_config
from models.interpolator import InterpolateSparse2d
from utils.depth_anything_wrapper import DepthAnythingExtractor
from utils.alike_wrapper import ALikeExtractor
from dataset import megadepth_wrapper
from dataset import coco_wrapper
from dataset.megadepth import MegaDepthDataset
from dataset.coco_augmentor import COCOAugmentor

import setproctitle


class Trainer():
    def __init__(self, megadepth_root_path, use_megadepth, megadepth_batch_size,
                 coco_root_path, use_coco, coco_batch_size,
                 ckpt_save_path,
                 model_name='LiftFeat',
                 n_steps=160_000, lr=3e-4, gamma_steplr=0.5,
                 training_res=(800, 608), device_num="0", dry_run=False,
                 save_ckpt_every=500):
        print(f'MegaDepth: {use_megadepth}-{megadepth_batch_size}')
        print(f'COCO20k: {use_coco}-{coco_batch_size}')

        self.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Training model: LiftFeat with the surface-normal and cross-attention branches enabled
        self.net = LiftFeatSPModel(featureboost_config, use_kenc=False, use_normal=True, use_cross=True).to(self.dev)
        self.loss_fn = LiftFeatLoss(self.dev, lam_descs=1, lam_kpts=2, lam_heatmap=1)

        # Depth-Anything model (supplies pseudo ground-truth surface normals)
        self.depth_net = DepthAnythingExtractor('vits', self.dev, 256)

        # ALike model (supplies pseudo ground-truth keypoints)
        self.alike_net = ALikeExtractor('alike-t', self.dev)

        # Setup optimizer and LR schedule
        self.steps = n_steps
        self.opt = optim.Adam(filter(lambda x: x.requires_grad, self.net.parameters()), lr=lr)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=10_000, gamma=gamma_steplr)
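        # With the defaults (lr=3e-4, step_size=10_000, gamma=0.5) the learning rate
        # halves every 10k steps; over the full 160k-step run it decays by a factor
        # of 0.5**16 ~= 1.5e-5.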

        ##################### COCO INIT ##########################
        self.use_coco = use_coco
        self.coco_batch_size = coco_batch_size
        if self.use_coco:
            self.augmentor = COCOAugmentor(
                img_dir=coco_root_path,
                device=self.dev, load_dataset=True,
                batch_size=self.coco_batch_size,
                out_resolution=training_res,
                warp_resolution=training_res,
                sides_crop=0.1,
                max_num_imgs=3000,
                num_test_imgs=5,
                photometric=True,
                geometric=True,
                reload_step=4000
            )
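            # The augmentor synthesizes training pairs by warping COCO20k images with
            # random homographies (plus photometric jitter); the warps H1/H2 it records
            # are what generate_train_data() uses to derive ground-truth correspondences.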
        ##################### COCO END #######################

        ##################### MEGADEPTH INIT ##########################
        self.use_megadepth = use_megadepth
        self.megadepth_batch_size = megadepth_batch_size
        if self.use_megadepth:
            TRAIN_BASE_PATH = f"{megadepth_root_path}/train_data/megadepth_indices"
            TRAINVAL_DATA_SOURCE = f"{megadepth_root_path}/MegaDepth_v1"
            TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7"
            npz_paths = glob.glob(TRAIN_NPZ_ROOT + '/*.npz')
            megadepth_dataset = torch.utils.data.ConcatDataset(
                [MegaDepthDataset(root_dir=TRAINVAL_DATA_SOURCE, npz_path=path)
                 for path in tqdm.tqdm(npz_paths, desc="[MegaDepth] Loading metadata")])
            self.megadepth_dataloader = DataLoader(megadepth_dataset, batch_size=megadepth_batch_size, shuffle=True)
            self.megadepth_data_iter = iter(self.megadepth_dataloader)
        ##################### MEGADEPTH INIT END #######################
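
        # MegaDepth batches are drawn from a persistent iterator rather than a fresh
        # dataloader each step; generate_train_data() re-creates the iterator whenever
        # an epoch is exhausted (StopIteration), so training can run for an arbitrary
        # number of steps.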

        os.makedirs(ckpt_save_path, exist_ok=True)
        os.makedirs(ckpt_save_path + '/logdir', exist_ok=True)

        self.dry_run = dry_run  # accepted for the sanity-check mode, currently unused in train()
        self.save_ckpt_every = save_ckpt_every
        self.ckpt_save_path = ckpt_save_path
        self.writer = SummaryWriter(ckpt_save_path + f'/logdir/{model_name}_' + time.strftime("%Y_%m_%d-%H_%M_%S"))
        self.model_name = model_name

    def generate_train_data(self):
        imgs1_t, imgs2_t = [], []
        imgs1_np, imgs2_np = [], []
        positives_coarse = []

        if self.use_coco:
            coco_imgs1, coco_imgs2, H1, H2 = coco_wrapper.make_batch(self.augmentor, 0.1)
            # Correspondences are supervised on a coarse grid at 1/8 of the input
            # resolution, matching the stride-8 supervision used for MegaDepth below.
            h_coarse, w_coarse = coco_imgs1[0].shape[-2] // 8, coco_imgs1[0].shape[-1] // 8
            _, positives_coco_coarse = coco_wrapper.get_corresponding_pts(
                coco_imgs1, coco_imgs2, H1, H2, self.augmentor, h_coarse, w_coarse)
            # Convert RGB to single-channel grayscale by averaging the channels
            coco_imgs1 = coco_imgs1.mean(1, keepdim=True)
            coco_imgs2 = coco_imgs2.mean(1, keepdim=True)
            imgs1_t.append(coco_imgs1)
            imgs2_t.append(coco_imgs2)
            positives_coarse += positives_coco_coarse

        if self.use_megadepth:
            try:
                megadepth_data = next(self.megadepth_data_iter)
            except StopIteration:
                print('End of MD DATASET')
                # Epoch exhausted: restart the dataloader and keep going
                self.megadepth_data_iter = iter(self.megadepth_dataloader)
                megadepth_data = next(self.megadepth_data_iter)

            if megadepth_data is not None:
                for k in megadepth_data.keys():
                    if isinstance(megadepth_data[k], torch.Tensor):
                        megadepth_data[k] = megadepth_data[k].to(self.dev)

                megadepth_imgs1_t, megadepth_imgs2_t = megadepth_data['image0'], megadepth_data['image1']
                megadepth_imgs1_t = megadepth_imgs1_t.mean(1, keepdim=True)
                megadepth_imgs2_t = megadepth_imgs2_t.mean(1, keepdim=True)
                imgs1_t.append(megadepth_imgs1_t)
                imgs2_t.append(megadepth_imgs2_t)

                # Keep the raw numpy images around for the Depth-Anything extractor
                megadepth_imgs1_np, megadepth_imgs2_np = megadepth_data['image0_np'], megadepth_data['image1_np']
                for np_idx in range(megadepth_imgs1_np.shape[0]):
                    img1_np = megadepth_imgs1_np[np_idx].squeeze(0).cpu().numpy()
                    img2_np = megadepth_imgs2_np[np_idx].squeeze(0).cpu().numpy()
                    imgs1_np.append(img1_np)
                    imgs2_np.append(img2_np)

                positives_megadepth_coarse = megadepth_wrapper.spvs_coarse(megadepth_data, 8)
                positives_coarse += positives_megadepth_coarse

        with torch.no_grad():
            imgs1_t = torch.cat(imgs1_t, dim=0)
            imgs2_t = torch.cat(imgs2_t, dim=0)

        return imgs1_t, imgs2_t, imgs1_np, imgs2_np, positives_coarse
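
    # generate_train_data() returns, in order:
    #   imgs1_t / imgs2_t   -- grayscale image batches (COCO pairs first, then MegaDepth)
    #   imgs1_np / imgs2_np -- raw numpy MegaDepth images for the depth/normal extractor
    #   positives_coarse    -- per-pair ground-truth matches on the stride-8 coarse grid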

    def train(self):
        self.net.train()

        with tqdm.tqdm(total=self.steps) as pbar:
            for i in range(self.steps):
                imgs1_t, imgs2_t, imgs1_np, imgs2_np, positives_coarse = self.generate_train_data()

                # Skip batches that are corrupted, i.e. have too few correspondences
                if any(len(p) < 30 for p in positives_coarse):
                    continue

                # Forward pass: dense descriptors, keypoint logits and surface normals
                feats1, kpts1, normals1 = self.net.forward1(imgs1_t)
                feats2, kpts2, normals2 = self.net.forward1(imgs2_t)
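
                # Two matching heads are evaluated per image pair: one on the raw
                # forward1 features, and one on the refined features that forward2
                # produces by fusing the feature map with the predicted keypoint and
                # normal maps (the fb_* outputs below).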
                coordinates, fb_coordinates = [], []
                alike_kpts1, alike_kpts2 = [], []
                DA_normals1, DA_normals2 = [], []
                fb_feats1, fb_feats2 = [], []

                for b in range(feats1.shape[0]):
                    # Flatten each feature map to (H*W, C) rows for the fine matcher
                    feat1 = feats1[b].permute(1, 2, 0).reshape(-1, feats1.shape[1])
                    feat2 = feats2[b].permute(1, 2, 0).reshape(-1, feats2.shape[1])
                    coordinate = self.net.fine_matcher(torch.cat([feat1, feat2], dim=-1))
                    coordinates.append(coordinate)

                    fb_feat1 = self.net.forward2(feats1[b].unsqueeze(0), kpts1[b].unsqueeze(0), normals1[b].unsqueeze(0))
                    fb_feat2 = self.net.forward2(feats2[b].unsqueeze(0), kpts2[b].unsqueeze(0), normals2[b].unsqueeze(0))
                    fb_coordinate = self.net.fine_matcher(torch.cat([fb_feat1, fb_feat2], dim=-1))
                    fb_coordinates.append(fb_coordinate)
                    fb_feats1.append(fb_feat1.unsqueeze(0))
                    fb_feats2.append(fb_feat2.unsqueeze(0))

                    # Replicate the grayscale image to 3 channels and rescale to [0, 255]
                    # for ALike keypoint extraction (pseudo ground-truth keypoints)
                    img1, img2 = imgs1_t[b], imgs2_t[b]
                    img1 = img1.permute(1, 2, 0).expand(-1, -1, 3).cpu().numpy() * 255
                    img2 = img2.permute(1, 2, 0).expand(-1, -1, 3).cpu().numpy() * 255
                    alike_kpt1 = torch.tensor(self.alike_net.extract_alike_kpts(img1), device=self.dev)
                    alike_kpt2 = torch.tensor(self.alike_net.extract_alike_kpts(img2), device=self.dev)
                    alike_kpts1.append(alike_kpt1)
                    alike_kpts2.append(alike_kpt2)

                # Depth-Anything provides pseudo ground-truth surface normals; only the
                # MegaDepth frames are in imgs1_np/imgs2_np, so only they get this supervision
                for b in range(len(imgs1_np)):
                    megadepth_depth1, megadepth_norm1 = self.depth_net.extract(imgs1_np[b])
                    megadepth_depth2, megadepth_norm2 = self.depth_net.extract(imgs2_np[b])
                    DA_normals1.append(megadepth_norm1)
                    DA_normals2.append(megadepth_norm2)

                # Fold the flattened per-image outputs back into batched (B, C, H, W) maps
                fb_feats1 = torch.cat(fb_feats1, dim=0)
                fb_feats2 = torch.cat(fb_feats2, dim=0)
                fb_feats1 = fb_feats1.reshape(feats1.shape[0], feats1.shape[2], feats1.shape[3], -1).permute(0, 3, 1, 2)
                fb_feats2 = fb_feats2.reshape(feats2.shape[0], feats2.shape[2], feats2.shape[3], -1).permute(0, 3, 1, 2)
                coordinates = torch.cat(coordinates, dim=0)
                coordinates = coordinates.reshape(feats1.shape[0], feats1.shape[2], feats1.shape[3], -1).permute(0, 3, 1, 2)
                fb_coordinates = torch.cat(fb_coordinates, dim=0)
                fb_coordinates = fb_coordinates.reshape(feats1.shape[0], feats1.shape[2], feats1.shape[3], -1).permute(0, 3, 1, 2)

                # Only the feature-boosted (fb_*) losses plus the keypoint and normal
                # losses enter the training objective; the raw descriptor/coordinate
                # losses are computed for logging only.
                loss_items = []
                loss_info = self.loss_fn(
                    feats1, fb_feats1, kpts1, normals1,
                    feats2, fb_feats2, kpts2, normals2,
                    positives_coarse,
                    coordinates, fb_coordinates,
                    alike_kpts1, alike_kpts2,
                    DA_normals1, DA_normals2,
                    self.megadepth_batch_size, self.coco_batch_size)

                loss_descs, acc_coarse = loss_info['loss_descs'], loss_info['acc_coarse']
                loss_coordinates, acc_coordinates = loss_info['loss_coordinates'], loss_info['acc_coordinates']
                loss_fb_descs, acc_fb_coarse = loss_info['loss_fb_descs'], loss_info['acc_fb_coarse']
                loss_fb_coordinates, acc_fb_coordinates = loss_info['loss_fb_coordinates'], loss_info['acc_fb_coordinates']
                loss_kpts, acc_kpt = loss_info['loss_kpts'], loss_info['acc_kpt']
                loss_normals = loss_info['loss_normals']

                loss_items.append(loss_fb_descs.unsqueeze(0))
                loss_items.append(loss_fb_coordinates.unsqueeze(0))
                loss_items.append(loss_kpts.unsqueeze(0))
                loss_items.append(loss_normals.unsqueeze(0))

                loss = torch.cat(loss_items, -1).mean()

                # Backward pass
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), 1.)
                self.opt.step()
                self.opt.zero_grad()
                self.scheduler.step()

                if (i + 1) % self.save_ckpt_every == 0:
                    print('saving iter ', i + 1)
                    torch.save(self.net.state_dict(), self.ckpt_save_path + f'/{self.model_name}_{i+1}.pth')

                pbar.set_description(
                    'Loss: {:.4f} '
                    'loss_descs: {:.3f} acc_coarse: {:.3f} '
                    'loss_coordinates: {:.3f} acc_coordinates: {:.3f} '
                    'loss_fb_descs: {:.3f} acc_fb_coarse: {:.3f} '
                    'loss_fb_coordinates: {:.3f} acc_fb_coordinates: {:.3f} '
                    'loss_kpts: {:.3f} acc_kpts: {:.3f} '
                    'loss_normals: {:.3f}'.format(
                        loss.item(),
                        loss_descs.item(), acc_coarse,
                        loss_coordinates.item(), acc_coordinates,
                        loss_fb_descs.item(), acc_fb_coarse,
                        loss_fb_coordinates.item(), acc_fb_coordinates,
                        loss_kpts.item(), acc_kpt,
                        loss_normals.item()))
                pbar.update(1)

                # Log metrics to TensorBoard
                self.writer.add_scalar('Loss/total', loss.item(), i)
                self.writer.add_scalar('Accuracy/acc_coarse', acc_coarse, i)
                self.writer.add_scalar('Accuracy/acc_coordinates', acc_coordinates, i)
                self.writer.add_scalar('Accuracy/acc_fb_coarse', acc_fb_coarse, i)
                self.writer.add_scalar('Accuracy/acc_fb_coordinates', acc_fb_coordinates, i)
                self.writer.add_scalar('Loss/descs', loss_descs.item(), i)
                self.writer.add_scalar('Loss/coordinates', loss_coordinates.item(), i)
                self.writer.add_scalar('Loss/fb_descs', loss_fb_descs.item(), i)
                self.writer.add_scalar('Loss/fb_coordinates', loss_fb_coordinates.item(), i)
                self.writer.add_scalar('Loss/kpts', loss_kpts.item(), i)
                self.writer.add_scalar('Loss/normals', loss_normals.item(), i)


if __name__ == '__main__':
    setproctitle.setproctitle(args.name)

    trainer = Trainer(
        megadepth_root_path=args.megadepth_root_path,
        use_megadepth=args.use_megadepth,
        megadepth_batch_size=args.megadepth_batch_size,
        coco_root_path=args.coco_root_path,
        use_coco=args.use_coco,
        coco_batch_size=args.coco_batch_size,
        ckpt_save_path=args.ckpt_save_path,
        n_steps=args.n_steps,
        lr=args.lr,
        gamma_steplr=args.gamma_steplr,
        training_res=args.training_res,
        device_num=args.device_num,
        dry_run=args.dry_run,
        save_ckpt_every=args.save_ckpt_every
    )

    # The most fun part
    trainer.train()
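
# A saved checkpoint can later be restored for evaluation or fine-tuning. A minimal
# sketch, assuming the same constructor arguments used above and the default
# model_name ('LiftFeat'), so checkpoints are named e.g. LiftFeat_500.pth:
#   net = LiftFeatSPModel(featureboost_config, use_kenc=False, use_normal=True, use_cross=True)
#   net.load_state_dict(torch.load(f'{args.ckpt_save_path}/LiftFeat_500.pth'))
#   net.eval()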