File size: 4,551 Bytes
288376d
 
 
 
 
2df809d
288376d
 
 
2df809d
 
288376d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2df809d
288376d
 
 
 
2df809d
288376d
 
 
 
 
2df809d
288376d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2df809d
288376d
 
 
 
 
 
 
 
2df809d
288376d
 
 
2df809d
288376d
 
 
 
 
 
 
 
2df809d
288376d
 
 
 
2df809d
288376d
 
2df809d
288376d
 
 
2df809d
288376d
 
 
 
 
 
 
 
 
 
 
2df809d
288376d
 
 
 
2df809d
288376d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
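# NOTE(review): every line of this module is commented out, so importing it
# defines nothing -- `Dust3rPipeline` is not available to importers as written.
# import contextlib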
# import os
# import os.path as osp
# import sys
# from typing import cast

# import imageio.v3 as iio
# import numpy as np
# import torch


# class Dust3rPipeline(object):
#     def __init__(self, device: str | torch.device = "cuda"):
#         submodule_path = osp.realpath(
#             osp.join(osp.dirname(__file__), "../../third_party/dust3r/")
#         )
#         if submodule_path not in sys.path:
#             sys.path.insert(0, submodule_path)
#         try:
#             with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
#                 from dust3r.cloud_opt import (  # type: ignore[import]
#                     GlobalAlignerMode,
#                     global_aligner,
#                 )
#                 from dust3r.image_pairs import make_pairs  # type: ignore[import]
#                 from dust3r.inference import inference  # type: ignore[import]
#                 from dust3r.model import AsymmetricCroCo3DStereo  # type: ignore[import]
#                 from dust3r.utils.image import load_images  # type: ignore[import]
#         except ImportError:
#             raise ImportError(
#                 "Missing required submodule: 'dust3r'. Please ensure that all submodules are properly set up.\n\n"
#                 "To initialize them, run the following command in the project root:\n"
#                 "  git submodule update --init --recursive"
#             )

#         self.device = torch.device(device)
#         self.model = AsymmetricCroCo3DStereo.from_pretrained(
#             "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
#         ).to(self.device)

#         self._GlobalAlignerMode = GlobalAlignerMode
#         self._global_aligner = global_aligner
#         self._make_pairs = make_pairs
#         self._inference = inference
#         self._load_images = load_images

#     def infer_cameras_and_points(
#         self,
#         img_paths: list[str],
#         Ks: list[list] = None,
#         c2ws: list[list] = None,
#         batch_size: int = 16,
#         schedule: str = "cosine",
#         lr: float = 0.01,
#         niter: int = 500,
#         min_conf_thr: int = 3,
#     ) -> tuple[
#         list[np.ndarray], np.ndarray, np.ndarray, list[np.ndarray], list[np.ndarray]
#     ]:
#         num_img = len(img_paths)
#         if num_img == 1:
#             print("Only one image found, duplicating it to create a stereo pair.")
#             img_paths = img_paths * 2

#         images = self._load_images(img_paths, size=512)
#         pairs = self._make_pairs(
#             images,
#             scene_graph="complete",
#             prefilter=None,
#             symmetrize=True,
#         )
#         output = self._inference(pairs, self.model, self.device, batch_size=batch_size)

#         ori_imgs = [iio.imread(p) for p in img_paths]
#         ori_img_whs = np.array([img.shape[1::-1] for img in ori_imgs])
#         img_whs = np.concatenate([image["true_shape"][:, ::-1] for image in images], 0)

#         scene = self._global_aligner(
#             output,
#             device=self.device,
#             mode=self._GlobalAlignerMode.PointCloudOptimizer,
#             same_focals=True,
#             optimize_pp=False,  # True,
#             min_conf_thr=min_conf_thr,
#         )

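#         # NOTE(review): the `Ks` argument is effectively unused -- the focal
#         # preset below is disabled and `Ks` is reassigned from the scene
#         # intrinsics after global alignment.
#         # if Ks is not None: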
#         #     scene.preset_focal(
#         #         torch.tensor([[K[0, 0], K[1, 1]] for K in Ks])
#         #     )

#         if c2ws is not None:
#             scene.preset_pose(c2ws)

#         _ = scene.compute_global_alignment(
#             init="msp", niter=niter, schedule=schedule, lr=lr
#         )

#         imgs = cast(list, scene.imgs)
#         Ks = scene.get_intrinsics().detach().cpu().numpy().copy()
#         c2ws = scene.get_im_poses().detach().cpu().numpy()  # type: ignore
#         pts3d = [x.detach().cpu().numpy() for x in scene.get_pts3d()]  # type: ignore
#         if num_img > 1:
#             masks = [x.detach().cpu().numpy() for x in scene.get_masks()]
#             points = [p[m] for p, m in zip(pts3d, masks)]
#             point_colors = [img[m] for img, m in zip(imgs, masks)]
#         else:
#             points = [p.reshape(-1, 3) for p in pts3d]
#             point_colors = [img.reshape(-1, 3) for img in imgs]

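#         # Convert back to the original image size: return the originally loaded
#         # images and rescale the intrinsics by the ratio of original to
#         # processed width/height. The last column of the top two rows is
#         # scaled per axis; the top-left 2x2 block is scaled by the mean of the
#         # two axis ratios.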
#         imgs = ori_imgs
#         Ks[:, :2, -1] *= ori_img_whs / img_whs
#         Ks[:, :2, :2] *= (ori_img_whs / img_whs).mean(axis=1, keepdims=True)[..., None]

#         return imgs, Ks, c2ws, points, point_colors