# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import time
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.cuda.amp import autocast

import hydra
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate

from lightglue import LightGlue, SuperPoint, SIFT, ALIKED
import pycolmap
from visdom import Visdom

from vggsfm.datasets.demo_loader import DemoLoader
from vggsfm.two_view_geo.estimate_preliminary import estimate_preliminary_cameras
try:
    import poselib
    from vggsfm.two_view_geo.estimate_preliminary import estimate_preliminary_cameras_poselib

    print("Poselib is available")
except ImportError:
    print("Poselib is not installed. Please disable use_poselib")
from vggsfm.utils.utils import (
    set_seed_and_print,
    farthest_point_sampling,
    calculate_index_mappings,
    switch_tensor_order,
)


# Hydra entry point. NOTE: the config_path/config_name below are assumed defaults
# (a "demo" config under cfgs/); adjust them to match your local config layout.
@hydra.main(config_path="cfgs/", config_name="demo")
def demo_fn(cfg: DictConfig):
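    """
    Demo entry point: load the scene from cfg.SCENE_DIR, run VGGSfM on each
    sequence, export the reconstruction in COLMAP format under output/<seq_name>,
    and optionally visualize the cameras and point cloud with Visdom.
    """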
    OmegaConf.set_struct(cfg, False)

    # Print configuration
    print("Model Config:", OmegaConf.to_yaml(cfg))

    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

    # Set seed
    seed_all_random_engines(cfg.seed)

    # Model instantiation
    model = instantiate(cfg.MODEL, _recursive_=False, cfg=cfg)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    # Prepare test dataset
    test_dataset = DemoLoader(
        SCENE_DIR=cfg.SCENE_DIR, img_size=cfg.img_size, normalize_cameras=False, load_gt=cfg.load_gt, cfg=cfg
    )
    if cfg.resume_ckpt:
        # Reload model
        checkpoint = torch.load(cfg.resume_ckpt)
        model.load_state_dict(checkpoint, strict=True)
        print(f"Successfully resumed from {cfg.resume_ckpt}")

    if cfg.visualize:
        from pytorch3d.structures import Pointclouds
        from pytorch3d.vis.plotly_vis import plot_scene
        from pytorch3d.renderer.cameras import PerspectiveCameras as PerspectiveCamerasVisual

        viz = Visdom()

    sequence_list = test_dataset.sequence_list
    for seq_name in sequence_list:
        print("*" * 50 + f" Testing on Scene {seq_name} " + "*" * 50)

        # Load the data
        batch, image_paths = test_dataset.get_data(sequence_name=seq_name, return_path=True)

        # Send to GPU
        images = batch["image"].to(device)
        crop_params = batch["crop_params"].to(device)

        # Unsqueeze to have batch size = 1
        images = images.unsqueeze(0)
        crop_params = crop_params.unsqueeze(0)

        batch_size = len(images)
        with torch.no_grad():
            # Run the model
            assert cfg.mixed_precision in ("None", "bf16", "fp16")
            if cfg.mixed_precision == "None":
                dtype = torch.float32
            elif cfg.mixed_precision == "bf16":
                dtype = torch.bfloat16
            elif cfg.mixed_precision == "fp16":
                dtype = torch.float16
            else:
                raise NotImplementedError(f"dtype {cfg.mixed_precision} is not supported now")

            predictions = run_one_scene(
                model,
                images,
                crop_params=crop_params,
                query_frame_num=cfg.query_frame_num,
                image_paths=image_paths,
                dtype=dtype,
                cfg=cfg,
            )
        # Export prediction as colmap format
        reconstruction_pycolmap = predictions["reconstruction"]
        output_path = os.path.join("output", seq_name)
        print("-" * 50)
        print(f"The output has been saved in COLMAP style at: {output_path}")
        os.makedirs(output_path, exist_ok=True)
        reconstruction_pycolmap.write(output_path)

        pred_cameras_PT3D = predictions["pred_cameras_PT3D"]
        if cfg.visualize:
            if "points3D_rgb" in predictions:
                pcl = Pointclouds(points=predictions["points3D"][None], features=predictions["points3D_rgb"][None])
            else:
                pcl = Pointclouds(points=predictions["points3D"][None])

            visual_cameras = PerspectiveCamerasVisual(
                R=pred_cameras_PT3D.R,
                T=pred_cameras_PT3D.T,
                device=pred_cameras_PT3D.device,
            )

            visual_dict = {"scenes": {"points": pcl, "cameras": visual_cameras}}

            fig = plot_scene(visual_dict, camera_scale=0.05)

            env_name = f"demo_visual_{seq_name}"
            print(f"Visualizing the scene by visdom at env: {env_name}")

            viz.plotlyplot(fig, env=env_name, win="3D")
    return True


def run_one_scene(model, images, crop_params=None, query_frame_num=3, image_paths=None, dtype=None, cfg=None):
| """ | |
| images have been normalized to the range [0, 1] instead of [0, 255] | |
| """ | |
    batch_num, frame_num, image_dim, height, width = images.shape
    device = images.device
    reshaped_image = images.reshape(batch_num * frame_num, image_dim, height, width)

    predictions = {}
    extra_dict = {}

    camera_predictor = model.camera_predictor
    track_predictor = model.track_predictor
    triangulator = model.triangulator
    # Find the query frames
    # First use DINO to find the most common frame among all the input frames,
    # i.e., the one with the highest (average) cosine similarity to all the others.
    # Then use farthest_point_sampling to find the next ones.
    # The number of query frames is determined by query_frame_num.
    with autocast(dtype=dtype):
        query_frame_indexes = find_query_frame_indexes(reshaped_image, camera_predictor, frame_num)

    image_paths = [os.path.basename(imgpath) for imgpath in image_paths]
    if cfg.center_order:
        # The code below switches the first frame (frame 0) with the most common frame
        center_frame_index = query_frame_indexes[0]
        center_order = calculate_index_mappings(center_frame_index, frame_num, device=device)
        images, crop_params = switch_tensor_order([images, crop_params], center_order, dim=1)
        reshaped_image = switch_tensor_order([reshaped_image], center_order, dim=0)[0]
        image_paths = [image_paths[i] for i in center_order.cpu().numpy().tolist()]

        # Also update query_frame_indexes:
        query_frame_indexes = [center_frame_index if x == 0 else x for x in query_frame_indexes]
        query_frame_indexes[0] = 0

    # only pick query_frame_num
    query_frame_indexes = query_frame_indexes[:query_frame_num]
    # Prepare image feature maps for tracker
    fmaps_for_tracker = track_predictor.process_images_to_fmaps(images)

    # Predict tracks
    with autocast(dtype=dtype):
        pred_track, pred_vis, pred_score = predict_tracks(
            cfg.query_method,
            cfg.max_query_pts,
            track_predictor,
            images,
            fmaps_for_tracker,
            query_frame_indexes,
            frame_num,
            device,
            cfg,
        )
        if cfg.comple_nonvis:
            pred_track, pred_vis, pred_score = comple_nonvis_frames(
                track_predictor,
                images,
                fmaps_for_tracker,
                frame_num,
                device,
                pred_track,
                pred_vis,
                pred_score,
                500,
                cfg=cfg,
            )

    torch.cuda.empty_cache()
    # If necessary, force all the predictions at the padding areas as non-visible
    if crop_params is not None:
        boundaries = crop_params[:, :, -4:-2].abs().to(device)
        boundaries = torch.cat([boundaries, reshaped_image.shape[-1] - boundaries], dim=-1)
        hvis = torch.logical_and(
            pred_track[..., 1] >= boundaries[:, :, 1:2], pred_track[..., 1] <= boundaries[:, :, 3:4]
        )
        wvis = torch.logical_and(
            pred_track[..., 0] >= boundaries[:, :, 0:1], pred_track[..., 0] <= boundaries[:, :, 2:3]
        )
        force_vis = torch.logical_and(hvis, wvis)
        pred_vis = pred_vis * force_vis.float()

    # TODO: plot 2D matches
    if cfg.use_poselib:
        estimate_preliminary_cameras_fn = estimate_preliminary_cameras_poselib
    else:
        estimate_preliminary_cameras_fn = estimate_preliminary_cameras

    # Estimate preliminary_cameras by recovering fundamental/essential/homography matrix from 2D matches
    # By default, we use fundamental matrix estimation with 7p/8p+LORANSAC
    # All the operations are batched and differentiable (if necessary),
    # except when you enable use_poselib to save GPU memory
    _, preliminary_dict = estimate_preliminary_cameras_fn(
        pred_track,
        pred_vis,
        width,
        height,
        tracks_score=pred_score,
        max_error=cfg.fmat_thres,
        loopresidual=True,
        # max_ransac_iters=cfg.max_ransac_iters,
    )

    pose_predictions = camera_predictor(reshaped_image, batch_size=batch_num)
    pred_cameras = pose_predictions["pred_cameras"]
    # Conduct Triangulation and Bundle Adjustment
    (
        BA_cameras_PT3D,
        extrinsics_opencv,
        intrinsics_opencv,
        points3D,
        points3D_rgb,
        reconstruction,
        valid_frame_mask,
    ) = triangulator(
        pred_cameras,
        pred_track,
        pred_vis,
        images,
        preliminary_dict,
        image_paths=image_paths,
        crop_params=crop_params,
        pred_score=pred_score,
        fmat_thres=cfg.fmat_thres,
        BA_iters=cfg.BA_iters,
        max_reproj_error=cfg.max_reproj_error,
        init_max_reproj_error=cfg.init_max_reproj_error,
        cfg=cfg,
    )
    if cfg.center_order:
        # NOTE we changed the image order previously, now we need to switch it back
        BA_cameras_PT3D = BA_cameras_PT3D[center_order]
        extrinsics_opencv = extrinsics_opencv[center_order]
        intrinsics_opencv = intrinsics_opencv[center_order]

    predictions["pred_cameras_PT3D"] = BA_cameras_PT3D
    predictions["extrinsics_opencv"] = extrinsics_opencv
    predictions["intrinsics_opencv"] = intrinsics_opencv
    predictions["points3D"] = points3D
    predictions["points3D_rgb"] = points3D_rgb
    predictions["reconstruction"] = reconstruction

    return predictions


def predict_tracks(
    query_method,
    max_query_pts,
    track_predictor,
    images,
    fmaps_for_tracker,
    query_frame_indexes,
    frame_num,
    device,
    cfg=None,
):
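    """
    Predict point tracks across all frames for every query frame in
    `query_frame_indexes`: detect query points on each query frame, reorder the
    frames so the query frame comes first, run the track predictor, and restore
    the original frame order. The per-query results are concatenated along the
    track dimension, returning (pred_track, pred_vis, pred_score).
    """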
    pred_track_list = []
    pred_vis_list = []
    pred_score_list = []

    for query_index in query_frame_indexes:
        print(f"Predicting tracks with query_index = {query_index}")

        # Find query_points at the query frame
        query_points = get_query_points(images[:, query_index], query_method, max_query_pts)

        # Switch so that the query_index frame stays at the first position
        # This largely simplifies the code structure of the tracker
        new_order = calculate_index_mappings(query_index, frame_num, device=device)
        images_feed, fmaps_feed = switch_tensor_order([images, fmaps_for_tracker], new_order)

        # Feed into track predictor
        fine_pred_track, _, pred_vis, pred_score = track_predictor(images_feed, query_points, fmaps=fmaps_feed)

        # Switch back the predictions
        fine_pred_track, pred_vis, pred_score = switch_tensor_order([fine_pred_track, pred_vis, pred_score], new_order)

        # Append predictions for different queries
        pred_track_list.append(fine_pred_track)
        pred_vis_list.append(pred_vis)
        pred_score_list.append(pred_score)

    pred_track = torch.cat(pred_track_list, dim=2)
    pred_vis = torch.cat(pred_vis_list, dim=2)
    pred_score = torch.cat(pred_score_list, dim=2)

    return pred_track, pred_vis, pred_score


def comple_nonvis_frames(
    track_predictor,
    images,
    fmaps_for_tracker,
    frame_num,
    device,
    pred_track,
    pred_vis,
    pred_score,
    min_vis=500,
    cfg=None,
):
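    """
    Complement frames with too few visible tracks: any frame with fewer than
    `min_vis` predicted-visible points is used as an extra query frame, and the
    newly predicted tracks are concatenated to the existing ones. If the same
    frame still falls below the threshold, a final pass with combined
    sp+sift+aliked queries is run before giving up.
    """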
    # If a frame has too few visible inliers, use it as a query
    non_vis_frames = torch.nonzero((pred_vis.squeeze(0) > 0.05).sum(-1) < min_vis).squeeze(-1).tolist()
    last_query = -1
    while len(non_vis_frames) > 0:
        print("Processing non visible frames")
        print(non_vis_frames)
        if non_vis_frames[0] == last_query:
            print("The non-visible frame still does not have enough 2D matches")
            pred_track_comple, pred_vis_comple, pred_score_comple = predict_tracks(
                "sp+sift+aliked",
                cfg.max_query_pts // 2,
                track_predictor,
                images,
                fmaps_for_tracker,
                non_vis_frames,
                frame_num,
                device,
                cfg,
            )

            # concat predictions
            pred_track = torch.cat([pred_track, pred_track_comple], dim=2)
            pred_vis = torch.cat([pred_vis, pred_vis_comple], dim=2)
            pred_score = torch.cat([pred_score, pred_score_comple], dim=2)
            break

        non_vis_query_list = [non_vis_frames[0]]
        last_query = non_vis_frames[0]
        pred_track_comple, pred_vis_comple, pred_score_comple = predict_tracks(
            cfg.query_method,
            cfg.max_query_pts,
            track_predictor,
            images,
            fmaps_for_tracker,
            non_vis_query_list,
            frame_num,
            device,
            cfg,
        )

        # concat predictions
        pred_track = torch.cat([pred_track, pred_track_comple], dim=2)
        pred_vis = torch.cat([pred_vis, pred_vis_comple], dim=2)
        pred_score = torch.cat([pred_score, pred_score_comple], dim=2)

        non_vis_frames = torch.nonzero((pred_vis.squeeze(0) > 0.05).sum(-1) < min_vis).squeeze(-1).tolist()

    return pred_track, pred_vis, pred_score


def find_query_frame_indexes(reshaped_image, camera_predictor, query_frame_num, image_size=336):
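    """
    Rank frames for use as queries: compute DINO patch features for each frame,
    build a frame-to-frame cosine-similarity matrix, pick the most "common"
    frame (highest average similarity to all others), then select the remaining
    frames by farthest point sampling on the corresponding distance matrix.
    """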
    # Downsample image to image_size x image_size
    # because we found it is unnecessary to use high resolution
    rgbs = F.interpolate(reshaped_image, (image_size, image_size), mode="bilinear", align_corners=True)
    rgbs = camera_predictor._resnet_normalize_image(rgbs)

    # Get the image features (patch level)
    frame_feat = camera_predictor.backbone(rgbs, is_training=True)
    frame_feat = frame_feat["x_norm_patchtokens"]
    frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)

    # Compute the similarity matrix
    frame_feat_norm = frame_feat_norm.permute(1, 0, 2)
    similarity_matrix = torch.bmm(frame_feat_norm, frame_feat_norm.transpose(-1, -2))
    similarity_matrix = similarity_matrix.mean(dim=0)
    distance_matrix = 1 - similarity_matrix.clone()
    # Ignore self-pairing
    similarity_matrix.fill_diagonal_(0)

    similarity_sum = similarity_matrix.sum(dim=1)

    # Find the most common frame
    most_common_frame_index = torch.argmax(similarity_sum).item()

    # Conduct FPS sampling
    # Starting from the most_common_frame_index,
    # try to find the farthest frame,
    # then the farthest to the last found frame
    # (frames are not allowed to be found twice)
    fps_idx = farthest_point_sampling(distance_matrix, query_frame_num, most_common_frame_index)

    return fps_idx


def get_query_points(query_image, query_method, max_query_num=4096, det_thres=0.005):
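    """
    Detect query keypoints on the query frame with the extractors named in
    `query_method` (e.g. "sp+sift+aliked"), concatenate them, and randomly
    subsample to at most `max_query_num` points.
    """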
    # Run superpoint and sift on the target frame
    # Feel free to modify this for your own use
    methods = query_method.split("+")
    pred_points = []

    for method in methods:
        if "sp" in method:
            extractor = SuperPoint(max_num_keypoints=max_query_num, detection_threshold=det_thres).cuda().eval()
        elif "sift" in method:
            extractor = SIFT(max_num_keypoints=max_query_num).cuda().eval()
        elif "aliked" in method:
            extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres).cuda().eval()
        else:
            raise NotImplementedError(f"query method {method} is not supported now")

        query_points = extractor.extract(query_image)["keypoints"]
        pred_points.append(query_points)

    query_points = torch.cat(pred_points, dim=1)

    if query_points.shape[1] > max_query_num:
        random_point_indices = torch.randperm(query_points.shape[1])[:max_query_num]
        query_points = query_points[:, random_point_indices, :]

    return query_points


def seed_all_random_engines(seed: int) -> None:
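    """Seed numpy, torch, and Python's random module for reproducibility."""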
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)


if __name__ == "__main__":
    with torch.no_grad():
        demo_fn()
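# Example invocation (hypothetical file name and paths; the overrides correspond to
# config keys read from cfg above, such as SCENE_DIR, resume_ckpt, and query_frame_num):
#   python demo.py SCENE_DIR=/path/to/scene resume_ckpt=/path/to/checkpoint.bin query_frame_num=3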