Spaces:
Running
on
L4
Running
on
L4
File size: 4,551 Bytes
288376d 2df809d 288376d 2df809d 288376d 2df809d 288376d 2df809d 288376d 2df809d 288376d 2df809d 288376d 2df809d 288376d 2df809d 288376d 2df809d 288376d 2df809d 288376d 2df809d 288376d 2df809d 288376d 2df809d 288376d 2df809d 288376d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import contextlib
import os
import os.path as osp
import sys
from typing import cast

import imageio.v3 as iio
import numpy as np
import torch
class Dust3rPipeline(object):
    """Thin wrapper around the DUSt3R dense stereo reconstruction model.

    On construction this makes the vendored ``dust3r`` git submodule
    importable, loads the pretrained ``AsymmetricCroCo3DStereo`` checkpoint
    onto the requested device, and keeps references to the dust3r helper
    functions needed at inference time.
    """

    def __init__(self, device: str | torch.device = "cuda"):
        """Load the DUSt3R model.

        Args:
            device: Torch device (or device string) to run the model on.

        Raises:
            ImportError: If the ``dust3r`` submodule is not initialized.
        """
        # Make the vendored dust3r submodule importable before importing it.
        submodule_path = osp.realpath(
            osp.join(osp.dirname(__file__), "../../third_party/dust3r/")
        )
        if submodule_path not in sys.path:
            sys.path.insert(0, submodule_path)
        try:
            # dust3r prints to stdout on import; silence it.
            with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
                from dust3r.cloud_opt import (  # type: ignore[import]
                    GlobalAlignerMode,
                    global_aligner,
                )
                from dust3r.image_pairs import make_pairs  # type: ignore[import]
                from dust3r.inference import inference  # type: ignore[import]
                from dust3r.model import AsymmetricCroCo3DStereo  # type: ignore[import]
                from dust3r.utils.image import load_images  # type: ignore[import]
        except ImportError as e:
            # Chain the original failure so the real missing module is visible.
            raise ImportError(
                "Missing required submodule: 'dust3r'. Please ensure that all submodules are properly set up.\n\n"
                "To initialize them, run the following command in the project root:\n"
                " git submodule update --init --recursive"
            ) from e

        self.device = torch.device(device)
        self.model = AsymmetricCroCo3DStereo.from_pretrained(
            "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
        ).to(self.device)
        # Stash the lazily imported dust3r symbols for use at inference time.
        self._GlobalAlignerMode = GlobalAlignerMode
        self._global_aligner = global_aligner
        self._make_pairs = make_pairs
        self._inference = inference
        self._load_images = load_images

    def infer_cameras_and_points(
        self,
        img_paths: list[str],
        Ks: list[list] | None = None,
        c2ws: list[list] | None = None,
        batch_size: int = 16,
        schedule: str = "cosine",
        lr: float = 0.01,
        niter: int = 500,
        min_conf_thr: int = 3,
    ) -> tuple[
        list[np.ndarray], np.ndarray, np.ndarray, list[np.ndarray], list[np.ndarray]
    ]:
        """Run DUSt3R inference plus global alignment on a set of images.

        Args:
            img_paths: Paths of the input images.
            Ks: Optional per-image intrinsics. NOTE(review): currently unused —
                the ``preset_focal`` call that would consume it is commented
                out below; confirm before relying on it.
            c2ws: Optional camera-to-world poses used to pin the alignment.
            batch_size: Inference batch size.
            schedule: Learning-rate schedule for the global alignment.
            lr: Learning rate for the global alignment.
            niter: Number of global-alignment iterations.
            min_conf_thr: Minimum confidence threshold for valid points.

        Returns:
            Tuple of (original images, intrinsics ``(N, 3, 3)``, camera-to-world
            poses ``(N, 4, 4)``, per-image 3D points, per-image point colors).
            If a single image was given it is duplicated, so ``N == 2``.
        """
        num_img = len(img_paths)
        if num_img == 1:
            # DUSt3R operates on image pairs, so a lone image is duplicated.
            print("Only one image found, duplicating it to create a stereo pair.")
            img_paths = img_paths * 2

        images = self._load_images(img_paths, size=512)
        pairs = self._make_pairs(
            images,
            scene_graph="complete",
            prefilter=None,
            symmetrize=True,
        )
        output = self._inference(pairs, self.model, self.device, batch_size=batch_size)

        # Original (W, H) sizes vs the 512-resized sizes dust3r worked at;
        # needed to rescale the recovered intrinsics at the end.
        ori_imgs = [iio.imread(p) for p in img_paths]
        ori_img_whs = np.array([img.shape[1::-1] for img in ori_imgs])
        img_whs = np.concatenate([image["true_shape"][:, ::-1] for image in images], 0)

        scene = self._global_aligner(
            output,
            device=self.device,
            mode=self._GlobalAlignerMode.PointCloudOptimizer,
            same_focals=True,
            optimize_pp=False,  # True,
            min_conf_thr=min_conf_thr,
        )
        # if Ks is not None:
        #     scene.preset_focal(
        #         torch.tensor([[K[0, 0], K[1, 1]] for K in Ks])
        #     )
        if c2ws is not None:
            scene.preset_pose(c2ws)
        _ = scene.compute_global_alignment(
            init="msp", niter=niter, schedule=schedule, lr=lr
        )

        imgs = cast(list, scene.imgs)
        Ks = scene.get_intrinsics().detach().cpu().numpy().copy()
        c2ws = scene.get_im_poses().detach().cpu().numpy()  # type: ignore
        pts3d = [x.detach().cpu().numpy() for x in scene.get_pts3d()]  # type: ignore
        if num_img > 1:
            # Keep only points above the confidence threshold, with colors.
            masks = [x.detach().cpu().numpy() for x in scene.get_masks()]
            points = [p[m] for p, m in zip(pts3d, masks)]
            point_colors = [img[m] for img, m in zip(imgs, masks)]
        else:
            points = [p.reshape(-1, 3) for p in pts3d]
            point_colors = [img.reshape(-1, 3) for img in imgs]

        # Convert the intrinsics back to the original image size: rescale the
        # principal point per-axis and the focal lengths by the mean scale.
        imgs = ori_imgs
        Ks[:, :2, -1] *= ori_img_whs / img_whs
        Ks[:, :2, :2] *= (ori_img_whs / img_whs).mean(axis=1, keepdims=True)[..., None]

        return imgs, Ks, c2ws, points, point_colors
|