bartduis committed on
Commit
70d1188
0 Parent(s):
.gitignore ADDED
@@ -0,0 +1,13 @@
1
+ *.out
2
+ slurm/
3
+ *.pyc
4
+ *.png
5
+ !assets/*.png
6
+ *.mtl
7
+ *.obj
8
+ *.ply
9
+ *.pth
10
+ **/build/**
11
+ *.so
12
+ wandb**
13
+ logs
LICENSE ADDED
@@ -0,0 +1,109 @@
1
+ RaySt3R
2
+ SOFTWARE LICENSE AGREEMENT
3
+ ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH
4
+ USE ONLY
5
+ BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO
6
+ THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH
7
+ THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE.
8
+
9
+ This is a license agreement ("Agreement") between your academic institution or non-
10
+ profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie
11
+
12
+ Mellon University (called "Licensor" in this Agreement). All rights not specifically
13
+ granted to you in this Agreement are reserved for Licensor.
14
+ RESERVATION OF OWNERSHIP AND GRANT OF LICENSE:
15
+ Licensor retains exclusive ownership of any copy of the Software (as defined below)
16
+ licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive,
17
+ non-transferable license to use the Software for noncommercial research purposes,
18
+ without the right to sublicense, pursuant to the terms and conditions of this Agreement.
19
+ As used in this Agreement, the term "Software" means (i) the actual copy of all or any
20
+ portion of code for program routines made accessible to Licensee by Licensor pursuant to
21
+ this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder
22
+ or subsequently supplied by Licensor, including all or any file structures, programming
23
+ instructions, user interfaces and screen formats and sequences as well as any and all
24
+ documentation and instructions related to it, and (ii) all or any derivatives and/or
25
+ modifications created or made by You to any of the items specified in (i).
26
+ CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to
27
+ Licensor, and as such, Licensee agrees to receive all such materials in confidence and use
28
+ the Software only in accordance with the terms of this Agreement. Licensee agrees to
29
+ use reasonable effort to protect the Software from unauthorized use, reproduction,
30
+ distribution, or publication.
31
+ COPYRIGHT: The Software is owned by Licensor and is protected by United
32
+ States copyright laws and applicable international treaties and/or conventions.
33
+ PERMITTED USES: The Software may be used for your own noncommercial internal
34
+ research purposes. You understand and agree that Licensor is not obligated to implement
35
+ any suggestions and/or feedback you might provide regarding the Software, but to the
36
+ extent Licensor does so, you are not entitled to any compensation related thereto.
37
+ DERIVATIVES: You may create derivatives of or make modifications to the Software,
38
+ however, You agree that all and any such derivatives and modifications will be owned by
39
+ Licensor and become a part of the Software licensed to You under this Agreement. You
40
+ may only use such derivatives and modifications for your own noncommercial internal
41
+ research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement. You must provide to Licensor one copy
42
+ of all such derivatives and modifications in a recognized electronic format by way of
43
+ electronic mail sent to Bardienus Pieter Duisterhof at [email protected]
44
+ within thirty (30) days of the publication date of any publication that relates to any such
45
+ derivatives or modifications. You understand that Licensor is not obligated to distribute
46
+ or otherwise make available any derivatives or modifications provided by You.
47
+ BACKUPS: If Licensee is an organization, it may make that number of copies of the
48
+ Software necessary for internal noncommercial use at a single site within its organization
49
+ provided that all information appearing in or on the original labels, including the
50
+ copyright and trademark notices are copied onto the labels of the copies.
51
+ USES NOT PERMITTED: You may not distribute, copy or use the Software except as
52
+ explicitly permitted herein. Licensee has not been granted any trademark license as part
53
+ of this Agreement and may not use the name or mark "RaySt3R" "Carnegie Mellon" or any renditions thereof without the prior written
54
+ permission of Licensor.
55
+ You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part,
56
+ or provide third parties access to prior or present versions (or any parts thereof) of the
57
+ Software.
58
+ ASSIGNMENT: You may not assign this Agreement or your rights hereunder without
59
+ the prior written consent of Licensor. Any attempted assignment without such consent
60
+ shall be null and void.
61
+ TERM: The term of the license granted by this Agreement is from Licensee's acceptance
62
+ of this Agreement by clicking "I Agree" below or by using the Software until terminated
63
+ as provided below.
64
+ The Agreement automatically terminates without notice if you fail to comply with any
65
+ provision of this Agreement. Licensee may terminate this Agreement by ceasing using
66
+ the Software. Upon any termination of this Agreement, Licensee will delete any and all
67
+ copies of the Software. You agree that all provisions which operate to protect the
68
+ proprietary rights of Licensor shall remain in force should breach occur and that the
69
+ obligation of confidentiality described in this Agreement is binding in perpetuity and, as
70
+ such, survives the term of the Agreement.
71
+ FEE: Provided Licensee abides completely by the terms and conditions of this
72
+ Agreement, there is no fee due to Licensor for Licensee's use of the Software in
73
+ accordance with this Agreement.
74
+ DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS"
75
+ WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF
76
+ PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR
77
+ USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK
78
+
79
+ RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND
80
+ RELATED MATERIALS.
81
+ SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is
82
+ provided as part of this Agreement.
83
+ EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent
84
+ permitted under applicable law, Licensor shall not be liable for direct, indirect, special,
85
+ incidental, or consequential damages or lost profits related to Licensee's use of and/or
86
+ inability to use the Software, even if Licensor is advised of the possibility of such
87
+ damage.
88
+ EXPORT REGULATION: You agree to comply with any and all applicable U.S. export
89
+ control laws, regulations, and/or other laws related to the embargoes and sanction
90
+ programs administered by the U.S. Office of Foreign Assets Control. You may not export
91
+ or re-export the technology with individuals or companies on the U.S. Department of
92
+ Commerce, Department of State or Department of Treasury denied party lists
93
+ https://www.trade.gov/consolidated-screening-list . You represent and warrant that
94
+ Licensee is not an individual or company listed on such denied party lists.
95
+ SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid,
96
+ illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the
97
+ validity, legality and enforceability of the remaining provisions shall not in any way be
98
+ affected or impaired thereby.
99
+ NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or
100
+ remedy under this Agreement shall be construed as a waiver of any future or other
101
+ exercise of such right or remedy by Licensor.
102
+ GOVERNING LAW: This Agreement shall be construed and enforced in accordance
103
+ with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws
104
+ principles. You consent to the personal jurisdiction of the courts of this County and
105
+ waive their rights to venue outside of Allegheny County, Pennsylvania.
106
+ ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole
107
+ and entire agreement between Licensee and Licensor as to the matter set forth herein and
108
+ supersedes any previous agreements, understandings, and arrangements between the
109
+ parties relating hereto.
app.py ADDED
@@ -0,0 +1,199 @@
1
+ import numpy as np
2
+ import gradio as gr
3
+ import torch
4
+ import rembg
5
+ import trimesh
6
+ from moge.model.v1 import MoGeModel
7
+ from utils.geometry import compute_pointmap
8
+ import os, shutil
9
+ import cv2
10
+ from huggingface_hub import hf_hub_download
11
+ from PIL import Image
12
+ import matplotlib.pyplot as plt
13
+ from eval_wrapper.eval import EvalWrapper, eval_scene
14
+ from torchvision import transforms
15
+
16
+ outdir = "/tmp/rayst3r"
17
+
18
+ # loading all necessary models
19
+ print("Loading DINOv2 model")
20
+ dino_model = torch.hub.load('facebookresearch/dinov2', "dinov2_vitl14_reg")
21
+ dino_model.eval()
22
+ dino_model.to("cuda")
23
+
24
+ print("Loading MoGe model")
25
+ device = torch.device("cuda")
26
+ # Load the model from huggingface hub (or load from local).
27
+ moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device)
28
+
29
+ print("Loading RaySt3R model")
30
+ rayst3r_checkpoint = hf_hub_download("bartduis/rayst3r", "rayst3r.pth")
31
+ rayst3r_model = EvalWrapper(rayst3r_checkpoint)
32
+
33
+ def depth2uint16(depth):
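+ # maps metric depth in [0, 10] m onto the full uint16 range; e.g. 1.0 m maps to 65535/10 = 6553.5 before the uint16 cast, 10 m saturates at 65535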
34
+ return depth * torch.iinfo(torch.uint16).max / 10.0 # threshold is in m, convert to uint16 value
35
+
36
+ def save_tensor_as_png(tensor: torch.Tensor, path: str, dtype: torch.dtype | None = None):
37
+ if dtype is None:
38
+ dtype = tensor.dtype
39
+ Image.fromarray(tensor.to(dtype).cpu().numpy()).save(path)
40
+
41
+ def colorize_points_with_turbo_all_dims(points, method='norm',cmap='turbo'):
42
+ """
43
+ Assigns colors to 3D points using the 'turbo' colormap based on a scalar computed from all 3 dimensions.
44
+
45
+ Args:
46
+ points (np.ndarray): (N, 3) array of 3D points.
47
+ method (str): Method for reducing 3D point to scalar. Options: 'norm', 'pca'.
48
+
49
+ Returns:
50
+ np.ndarray: (N, 3) RGB colors in [0, 1].
51
+ """
52
+ assert points.shape[1] == 3, "Input must be of shape (N, 3)"
53
+
54
+ if method == 'norm':
55
+ scalar = np.linalg.norm(points, axis=1)
56
+ elif method == 'pca':
57
+ # Project onto first principal component
58
+ mean = points.mean(axis=0)
59
+ centered = points - mean
60
+ u, s, vh = np.linalg.svd(centered, full_matrices=False)
61
+ scalar = centered @ vh[0] # Project onto first principal axis
62
+ else:
63
+ raise ValueError(f"Unknown method '{method}'")
64
+
65
+ # Normalize scalar to [0, 1]
66
+ scalar_min, scalar_max = scalar.min(), scalar.max()
67
+ normalized = (scalar - scalar_min) / (scalar_max - scalar_min + 1e-8)
68
+
69
+ # Apply turbo colormap
70
+ cmap = plt.colormaps.get_cmap(cmap)
71
+ colors = cmap(normalized)[:, :3] # Drop alpha
72
+
73
+ return colors
74
+
75
+ def prep_for_rayst3r(img,depth_dict,mask):
76
+ H, W = img.shape[:2]
77
+ intrinsics = depth_dict["intrinsics"].detach().cpu()
78
+ intrinsics[0] *= W
79
+ intrinsics[1] *= H
80
+
81
+ input_dir = os.path.join(outdir, "input")
82
+ if os.path.exists(input_dir):
83
+ shutil.rmtree(input_dir)
84
+ os.makedirs(input_dir, exist_ok=True)
85
+ # save intrinsics
86
+ torch.save(intrinsics, os.path.join(input_dir, "intrinsics.pt"))
87
+
88
+ # save depth
89
+ depth = depth_dict["depth"].cpu()
90
+ depth = depth2uint16(depth)
91
+ save_tensor_as_png(depth, os.path.join(input_dir, "depth.png"),dtype=torch.uint16)
92
+
93
+ # save mask as bool
94
+ save_tensor_as_png(torch.from_numpy(mask).bool(), os.path.join(input_dir, "mask.png"),dtype=torch.bool)
95
+ # save image
96
+ save_tensor_as_png(torch.from_numpy(img), os.path.join(input_dir, "rgb.png"))
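+ # the input folder now holds what eval_scene's loader expects: intrinsics.pt, depth.png (uint16), mask.png and rgb.png (cam2world.pt is optional and defaults to identity)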
97
+
98
+ def rayst3r_to_glb(img,depth_dict,mask,max_total_points=10e6,rotated=False):
99
+ prep_for_rayst3r(img,depth_dict,mask)
100
+ rayst3r_points = eval_scene(rayst3r_model,os.path.join(outdir, "input"),do_filter_all_masks=True,dino_model=dino_model).cpu()
101
+
102
+ # subsample points
103
+ n_points = int(min(max_total_points, rayst3r_points.shape[0])) # cast to int: max_total_points defaults to the float 10e6
104
+ rayst3r_points = rayst3r_points[torch.randperm(rayst3r_points.shape[0])[:n_points]].numpy()
105
+
106
+ rayst3r_points[:,1] = -rayst3r_points[:,1]
107
+ rayst3r_points[:,2] = -rayst3r_points[:,2]
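+ # flip Y and Z: the points come in an OpenCV-style camera frame (y down, z forward), while the glb viewer expects y up and z toward the viewer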
108
+
109
+ # color the points with the turbo colormap
110
+ colors = colorize_points_with_turbo_all_dims(rayst3r_points)
111
+
112
+ # load the input glb
113
+ scene = trimesh.Scene()
114
+ pct = trimesh.PointCloud(rayst3r_points, colors=colors, radius=0.01)
115
+ scene.add_geometry(pct)
116
+
117
+ outfile = os.path.join(outdir, "rayst3r.glb")
118
+ scene.export(outfile)
119
+ return outfile
120
+
121
+
122
+ def input_to_glb(outdir,img,depth_dict,mask,rotated=False):
123
+ H, W = img.shape[:2]
124
+ intrinsics = depth_dict["intrinsics"].cpu().numpy()
125
+ intrinsics[0] *= W
126
+ intrinsics[1] *= H
127
+
128
+ depth = depth_dict["depth"].cpu().numpy()
129
+ cam2world = np.eye(4)
130
+ points_world = compute_pointmap(depth, cam2world, intrinsics)
131
+
132
+ scene = trimesh.Scene()
133
+ pts = np.concatenate([p[m] for p,m in zip(points_world,mask)])
134
+ col = np.concatenate([c[m] for c,m in zip(img,mask)])
135
+
136
+ pts = pts.reshape(-1,3)
137
+ pts[:,1] = -pts[:,1]
138
+ pts[:,2] = -pts[:,2]
139
+
140
+
141
+ pct = trimesh.PointCloud(pts, colors=col.reshape(-1,3))
142
+ scene.add_geometry(pct)
143
+
144
+ outfile = os.path.join(outdir, "input.glb")
145
+ scene.export(outfile)
146
+ return outfile
147
+
148
+ def depth_moge(input_img):
149
+ input_img_torch = torch.tensor(input_img / 255, dtype=torch.float32, device=device).permute(2, 0, 1)
150
+ output = moge_model.infer(input_img_torch)
151
+ return output
152
+
153
+ def mask_rembg(input_img):
154
+ #masked_img = rembg.remove(input_img,)
155
+ output_img = rembg.remove(input_img, alpha_matting=False, post_process_mask=True)
156
+
157
+ # Convert to NumPy array
158
+ output_np = np.array(output_img)
159
+ alpha = output_np[..., 3]
160
+
161
+ # Step 2: Erode the alpha mask to shrink object slightly
162
+ kernel = np.ones((3, 3), np.uint8) # Adjust size for aggressiveness
163
+ eroded_alpha = cv2.erode(alpha, kernel, iterations=1)
164
+ # Step 3: Replace alpha channel
165
+ output_np[..., 3] = eroded_alpha
166
+
167
+ mask = output_np[:,:,-1] >= 128
168
+ rgb = output_np[:,:,:3]
169
+ return mask, rgb
170
+
171
+ def process_image(input_img):
172
+ # resize the input image
173
+ rotated = False
174
+ #if input_img.shape[0] > input_img.shape[1]:
175
+ #input_img = cv2.rotate(input_img, cv2.ROTATE_90_COUNTERCLOCKWISE)
176
+ #rotated = True
177
+ input_img = cv2.resize(input_img, (640, 480))
178
+ mask, rgb = mask_rembg(input_img)
179
+ depth_dict = depth_moge(input_img)
180
+
181
+ if os.path.exists(outdir):
182
+ shutil.rmtree(outdir)
183
+ os.makedirs(outdir)
184
+
185
+ input_glb = input_to_glb(outdir,input_img,depth_dict,mask,rotated=rotated)
186
+
187
+ # visualize the input points in 3D in gradio
188
+ inference_glb = rayst3r_to_glb(input_img,depth_dict,mask,rotated=rotated)
189
+
190
+ return input_glb, inference_glb
191
+
192
+ demo = gr.Interface(
193
+ process_image,
194
+ gr.Image(),
195
+ [gr.Model3D(label="Input"), gr.Model3D(label="RaySt3R",)]
196
+ )
197
+
198
+ if __name__ == "__main__":
199
+ demo.launch()
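+ # run locally with e.g. `python app.py` and open the printed Gradio URL; a CUDA GPU is assumed since all models are loaded onto 'cuda'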
datasets/generic_loader.py ADDED
@@ -0,0 +1,168 @@
1
+ bb = breakpoint
2
+ import torch
3
+ import trimesh
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import os
7
+ from pathlib import Path
8
+ import pickle
9
+ import tqdm
10
+ import json
11
+ from PIL import Image
12
+
13
+ class GenericLoader(torch.utils.data.Dataset):
14
+ def __init__(self,dir="octmae_data/tiny_train/train_processed",seed=747,size=10,datasets=["fp_objaverse"],split="train",dtype=torch.float32,mode="slow",
15
+ prefetch_dino=False,dino_features=[23],view_select_mode="new_zoom",noise_std=0.0,rendered_views_mode="None",**kwargs):
16
+ super().__init__(**kwargs)
17
+ self.dir = dir
18
+ self.rng = np.random.default_rng(seed)
19
+ self.size = size
20
+ self.datasets = datasets
21
+ self.split = split
22
+ self.dtype = dtype
23
+ self.mode = mode
24
+ self.prefetch_dino = prefetch_dino
25
+ self.view_select_mode = view_select_mode
26
+ self.noise_std = noise_std * torch.iinfo(torch.uint16).max / 10.0 # std of the depth noise in uint16 units (10 m spans the full uint16 range)
27
+ if self.mode == 'slow':
28
+ self.prefetch_dino = True
29
+ self.find_scenes()
30
+ self.dino_features = dino_features
31
+ self.rendered_views_mode = rendered_views_mode
32
+
33
+ def find_dataset_location_list(self,dataset):
34
+ data_dir = None
35
+ for d in self.dir:
36
+ datasets = os.listdir(d)
37
+ if dataset in datasets:
38
+ if data_dir is not None:
39
+ raise ValueError(f"Dataset {dataset} found in multiple locations: {self.dir}")
40
+ else:
41
+ data_dir = os.path.join(d,dataset)
42
+ if data_dir is None:
43
+ raise ValueError(f"Dataset {dataset} not found in {self.dir}")
44
+ return data_dir
45
+
46
+ def find_dataset_location(self,dataset):
47
+ if isinstance(self.dir,list):
48
+ data_dir = self.find_dataset_location_list(dataset)
49
+ else:
50
+ data_dir = os.path.join(self.dir,dataset)
51
+ if not os.path.exists(data_dir):
52
+ raise ValueError(f"Dataset {dataset} not found in {self.dir}")
53
+ return data_dir
54
+
55
+ def find_scenes(self):
56
+ all_scenes = {}
57
+ print("Loading scenes...")
58
+ for dataset in self.datasets:
59
+ dataset_dir = self.find_dataset_location(dataset)
60
+ scenes = json.load(open(os.path.join(dataset_dir, f"{self.split}_scenes.json")))
61
+ scene_ids = [dataset + "_" + f.split("/")[-2] + "_" + f.split("/")[-1] for f in scenes]
62
+ all_scenes.update(dict(zip(scene_ids, scenes)))
63
+ self.scenes = all_scenes
64
+ self.scene_ids = list(self.scenes.keys())
65
+ # shuffle the scene ids
66
+ self.rng.shuffle(self.scene_ids)
67
+ if self.size > 0:
68
+ self.scene_ids = self.scene_ids[:self.size]
69
+ self.size = len(self.scene_ids)
70
+ return scenes
71
+
72
+ def __len__(self):
73
+ return self.size
74
+
75
+ def decide_context_view(self,cam_dir):
76
+ # we pick the view furthest away from the origin as the view for conditioning
77
+ cam_dirs = [d for d in os.listdir(cam_dir) if os.path.isdir(os.path.join(cam_dir,d)) and not d.startswith("gen")] # input cam needs rgb
78
+
79
+ extrinsics = {c:torch.load(os.path.join(cam_dir,c,'cam2world.pt'),map_location='cpu',weights_only=True) for c in cam_dirs}
80
+ dist_origin = {c:torch.linalg.norm(extrinsics[c][:3,3]) for c in extrinsics}
81
+
82
+ if self.view_select_mode == 'new_zoom':
83
+ # find the view with the maximum distance to the origin
84
+ input_cam = max(dist_origin,key=dist_origin.get)
85
+ # pick another random view to predict, excluding the context view
86
+ elif self.view_select_mode == 'random':
87
+ # pick a random view
88
+ input_cam = str(self.rng.choice(list(dist_origin.keys())))
89
+ # pick another random view to predict, excluding the context view
90
+ else:
91
+ raise ValueError(f"Invalid mode: {self.view_select_mode}")
92
+
93
+ if self.rendered_views_mode == "None":
94
+ pass
95
+ elif self.rendered_views_mode == "random":
96
+ cam_dirs = [d for d in os.listdir(cam_dir) if os.path.isdir(os.path.join(cam_dir,d))]
97
+ elif self.rendered_views_mode == "always":
98
+ cam_dirs_gen = [d for d in os.listdir(cam_dir) if os.path.isdir(os.path.join(cam_dir,d)) and d.startswith("gen")]
99
+ if len(cam_dirs_gen) > 0:
100
+ cam_dirs = cam_dirs_gen
101
+ else:
102
+ raise ValueError(f"Invalid mode: {self.rendered_views_mode}")
103
+
104
+ possible_views = [v for v in cam_dirs if v != input_cam]
105
+ new_cam = str(self.rng.choice(possible_views))
106
+ return input_cam,new_cam
107
+
108
+ def transform_pointmap(self,pointmap_cam,c2w):
109
+ # pointmap: shape H x W x 3
110
+ # c2w: shape 4 x 4
111
+ # we want to transform the pointmap to the world frame
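+ # i.e. x_world = (c2w @ [x_cam, 1])[:3]; the per-pixel points are stored as row vectors, hence the right-multiplication by c2w.T below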
112
+ pointmap_cam_h = torch.cat([pointmap_cam,torch.ones(pointmap_cam.shape[:-1]+(1,)).to(pointmap_cam.device)],dim=-1)
113
+ pointmap_world_h = pointmap_cam_h @ c2w.T
114
+ pointmap_world = pointmap_world_h[...,:3]/pointmap_world_h[...,3:4]
115
+ return pointmap_world
116
+
117
+ def load_scene_slow(self,input_cam,new_cam,cam_dir):
118
+
119
+ data = dict(new_cams={},input_cams={})
120
+
121
+ data['new_cams']['c2ws'] = [torch.load(os.path.join(cam_dir,new_cam,'cam2world.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
122
+ data['new_cams']['depths'] = [torch.load(os.path.join(cam_dir,new_cam,'depth.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
123
+ data['new_cams']['pointmaps'] = [self.transform_pointmap(torch.load(os.path.join(cam_dir,new_cam,'pointmap.pt'),map_location='cpu',weights_only=True).to(self.dtype),data['new_cams']['c2ws'][0])]
124
+ data['new_cams']['Ks'] = [torch.load(os.path.join(cam_dir,new_cam,'intrinsics.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
125
+ data['new_cams']['valid_masks'] = [torch.load(os.path.join(cam_dir,new_cam,'mask.pt'),map_location='cpu',weights_only=True).to(torch.bool)]
126
+
127
+ # add the context views
128
+ data['input_cams']['c2ws'] = [torch.load(os.path.join(cam_dir,input_cam,'cam2world.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
129
+ data['input_cams']['depths'] = [torch.load(os.path.join(cam_dir,input_cam,'depth.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
130
+ data['input_cams']['pointmaps'] = [self.transform_pointmap(torch.load(os.path.join(cam_dir,input_cam,'pointmap.pt'),map_location='cpu',weights_only=True).to(self.dtype),data['input_cams']['c2ws'][0])]
131
+ data['input_cams']['Ks'] = [torch.load(os.path.join(cam_dir,input_cam,'intrinsics.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
132
+ data['input_cams']['valid_masks'] = [torch.load(os.path.join(cam_dir,input_cam,'mask.pt'),map_location='cpu',weights_only=True).to(torch.bool)]
133
+ data['input_cams']['imgs'] = [torch.load(os.path.join(cam_dir,input_cam,'rgb.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
134
+ data['input_cams']['dino_features'] = [torch.load(os.path.join(cam_dir,input_cam,f'dino_features_layer_{l}.pt'),map_location='cpu',weights_only=True).to(self.dtype) for l in self.dino_features]
135
+ return data
136
+
137
+ def depth_to_metric(self,depth):
138
+ # depth: shape H x W
139
+ # we want to convert the depth to a metric depth
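+ # inverse of the uint16 encoding used when the depth maps were saved: 65535 -> 10.0 m, 0 -> 0 m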
140
+ depth_max = 10.0
141
+ depth_scaled = depth_max * (depth / 65535.0)
142
+ return depth_scaled
143
+
144
+ def load_scene_fast(self,input_cam,new_cam,cam_dir):
145
+ data = dict(new_cams={},input_cams={})
146
+ data['new_cams']['c2ws'] = [torch.load(os.path.join(cam_dir,new_cam,'cam2world.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
147
+ data['new_cams']['Ks'] = [torch.load(os.path.join(cam_dir,new_cam,'intrinsics.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
148
+ data['new_cams']['depths'] = [torch.from_numpy(np.array(Image.open(os.path.join(cam_dir,new_cam,'depth.png'))).astype(np.float32))]
149
+ data['new_cams']['valid_masks'] = [torch.from_numpy(np.array(Image.open(os.path.join(cam_dir,new_cam,'mask.png'))))]
150
+
151
+ data['input_cams']['c2ws'] = [torch.load(os.path.join(cam_dir,input_cam,'cam2world.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
152
+ data['input_cams']['Ks'] = [torch.load(os.path.join(cam_dir,input_cam,'intrinsics.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
153
+ data['input_cams']['depths'] = [torch.from_numpy(np.array(Image.open(os.path.join(cam_dir,input_cam,'depth.png'))).astype(np.float32))]
154
+ data['input_cams']['valid_masks'] = [torch.from_numpy(np.array(Image.open(os.path.join(cam_dir,input_cam,'mask.png'))))]
155
+ data['input_cams']['imgs'] = [torch.from_numpy(np.array(Image.open(os.path.join(cam_dir,input_cam,'rgb.png'))))]
156
+ if self.prefetch_dino:
157
+ data['input_cams']['dino_features'] = [torch.cat([torch.load(os.path.join(cam_dir,input_cam,f'dino_features_layer_{l}.pt'),map_location='cpu',weights_only=True).to(self.dtype) for l in self.dino_features],dim=-1)]
158
+ return data
159
+
160
+ def __getitem__(self,idx):
161
+ cam_dir = os.path.join(self.scenes[self.scene_ids[idx]],'cameras')
162
+ #data['input_cams'] = {k:[v[0].unsqueeze(0)] for k,v in data['input_cams'].items()}
163
+ input_cam,new_cam = self.decide_context_view(cam_dir)
164
+ if self.mode == 'slow':
165
+ data = self.load_scene_slow(input_cam,new_cam,cam_dir)
166
+ else:
167
+ data = self.load_scene_fast(input_cam,new_cam,cam_dir)
168
+ return data
engine.py ADDED
@@ -0,0 +1,139 @@
1
+ bb=breakpoint
2
+ import torch
3
+ from utils.geometry import center_pointmaps, uncenter_pointmaps
4
+ from utils.utils import scenes_to_batch, batch_to_scenes
5
+ from utils.batch_prep import prepare_fast_batch, normalize_batch, denormalize_batch
6
+ from utils.viz import save_pointmaps
7
+ from tqdm import tqdm
8
+ import wandb
9
+ from utils import misc
10
+ from torch.amp import GradScaler
11
+ from utils.eval import eval_pred
12
+ from utils.geometry import depth2pts
13
+
14
+ def batch_to_device(batch,device='cuda'):
15
+ for key in batch:
16
+ if isinstance(batch[key],torch.Tensor):
17
+ batch[key] = batch[key].to(device)
18
+ elif isinstance(batch[key],dict):
19
+ batch[key] = batch_to_device(batch[key],device)
20
+ return batch
21
+
22
+ def eval_model(model,batch,mode='loss',device='cuda',dino_model=None,args=None,augmentor=None,return_scale=False):
23
+ batch = batch_to_device(batch,device)
24
+ # check if model is distributed
25
+ if isinstance(model,torch.nn.parallel.DistributedDataParallel):
26
+ dino_layers = model.module.dino_layers
27
+ else:
28
+ dino_layers = model.dino_layers
29
+ if 'pointmaps' not in list(batch['input_cams'].keys()):
30
+ batch = prepare_fast_batch(batch,dino_model,dino_layers)
31
+
32
+ normalize_mode = args.normalize_mode if args is not None else 'median'
33
+ batch, scale_factors = normalize_batch(batch,normalize_mode)
34
+ if augmentor is not None:
35
+ batch = augmentor(batch)
36
+
37
+ batch, n_cams = scenes_to_batch(batch)
38
+ batch = center_pointmaps(batch) # centering around first camera
39
+
40
+ device = args.device if args is not None else 'cuda'
41
+ with torch.amp.autocast(device_type=device, dtype=torch.bfloat16):
42
+ pred, gt, loss_dict = model(batch,mode='viz')
43
+
44
+ if 'pointmaps' not in list(pred.keys()):
45
+ pred['pointmaps'] = depth2pts(pred['depths'].squeeze(-1),batch['new_cams']['Ks'])
46
+ elif 'depths' not in list(pred.keys()):
47
+ pred['depths'] = pred['pointmaps'][...,-1]
48
+ loss_dict = {**loss_dict,**eval_pred(pred, gt)}
49
+ if mode == 'loss':
50
+ return loss_dict
51
+ elif mode == 'viz':
52
+ pred, gt, batch = uncenter_pointmaps(pred, gt, batch)
53
+ pred, gt, batch = batch_to_scenes(pred, gt,batch, n_cams)
54
+ if return_scale:
55
+ return pred, gt, loss_dict, scale_factors[0].item()
56
+ else:
57
+ return pred, gt, loss_dict
58
+ else:
59
+ raise ValueError(f"Invalid mode: {mode}")
60
+
61
+ def update_loss_dict(loss_dict,loss_dict_new,sample_count):
62
+ for key in loss_dict_new:
63
+ if key not in loss_dict:
64
+ loss_dict[key] = loss_dict_new[key]
65
+ else:
66
+ # previously stored value in loss_dict is average from sample_count samples
67
+ # so we need to update it to include the new sample
68
+ loss_dict[key] = (loss_dict[key] * sample_count + loss_dict_new[key]) / (sample_count + 1)
69
+ return loss_dict
70
+
71
+ def train_epoch(model, train_loader, optimizer, device='cuda', max_norm=1.0,log_wandb=False,epoch=0,batch_size=None,args=None,dino_model=None,augmentor=None):
72
+ model.train()
73
+ all_losses_dict = {}
74
+
75
+ sample_idx = epoch * batch_size * len(train_loader)
76
+ scaler = GradScaler()
77
+ for i, batch in tqdm(enumerate(train_loader),total=len(train_loader)):
78
+ optimizer.zero_grad()
79
+ new_loss_dict = eval_model(model, batch, mode='loss', device=device,dino_model=dino_model,args=args,augmentor=augmentor)
80
+ loss = new_loss_dict['loss']
81
+ if loss is None:
82
+ continue
83
+
84
+ scaler.scale(loss).backward()
85
+ # Unscales the gradients of optimizer's assigned params in-place
86
+ scaler.unscale_(optimizer)
87
+
88
+ grad_norm = torch.norm(torch.stack([torch.norm(p.grad) for p in model.parameters() if p.grad is not None]))
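+ # global gradient norm over all parameters; used for logging and to catch NaNs (debugger hook below)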
89
+ if grad_norm.isnan():
90
+ breakpoint()
91
+
92
+ ## Since the gradients of optimizer's assigned params are unscaled, clips as usual:
93
+ if max_norm > 0:
94
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
95
+
96
+ # optimizer's gradients are already unscaled, so scaler.step does not unscale them,
97
+ # although it still skips optimizer.step() if the gradients contain infs or NaNs.
98
+ scaler.step(optimizer)
99
+
100
+ # Updates the scale for next iteration.
101
+ scaler.update()
102
+
103
+ new_loss_dict['grad_norm'] = grad_norm.detach().cpu().item()
104
+
105
+ misc.adjust_learning_rate(optimizer, epoch + i/len(train_loader), args)
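+ # the fractional epoch (epoch + i/len(train_loader)) gives a per-iteration learning-rate schedule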
106
+ optimizer.step()
107
+
108
+ new_loss_dict = {k: (v.detach().cpu().item() if isinstance(v, torch.Tensor) else v) for k, v in new_loss_dict.items()}
109
+ if log_wandb:
110
+ wandb_dict = {f"train_{k}":v for k,v in new_loss_dict.items()}
111
+ wandb.log(wandb_dict, step=sample_idx + (i+1)*batch_size)
112
+ # log learning rate
113
+ wandb.log({"train_lr": optimizer.param_groups[0]['lr']}, step=sample_idx + (i+1)*batch_size)
114
+
115
+ all_losses_dict = update_loss_dict(all_losses_dict, new_loss_dict,sample_count=i)
116
+ # Clear cache and delete variables to free memory
117
+ torch.cuda.empty_cache()
118
+ del loss
119
+ del new_loss_dict
120
+ del grad_norm
121
+ del batch
122
+
123
+ return all_losses_dict
124
+
125
+ def eval_epoch(model,test_loader,device='cuda',dino_model=None,args=None,augmentor=None):
126
+ model.eval()
127
+ all_losses_dict = {}
128
+ with torch.no_grad():
129
+ for i, batch in tqdm(enumerate(test_loader),total=len(test_loader)):
130
+ new_loss_dict = eval_model(model,batch,mode='loss',device=device,dino_model=dino_model,args=args,augmentor=augmentor)
131
+ if new_loss_dict is None:
132
+ continue
133
+ all_losses_dict = update_loss_dict(all_losses_dict,new_loss_dict,sample_count=i)
134
+
135
+ torch.cuda.empty_cache()
136
+ del new_loss_dict
137
+ del batch
138
+
139
+ return all_losses_dict
eval_wrapper/eval.py ADDED
@@ -0,0 +1,425 @@
1
+ from PIL import Image
2
+ import numpy as np
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+ from torchvision import transforms
6
+ import os
7
+ import sys
8
+ import open3d as o3d
9
+ current_dir = os.getcwd()
10
+ sys.path.append(current_dir)
11
+
12
+ from eval_wrapper.sample_poses import pointmap_to_poses
13
+ from utils.fusion import fuse_batch
14
+ from models.rayquery import *
15
+ from models.losses import *
16
+ import argparse
+ import copy # used below in eval_scene (copy.deepcopy)
17
+ from utils import misc
18
+ import torch.distributed as dist
19
+ from utils.collate import collate
20
+ from engine import eval_model
21
+ from utils.viz import just_load_viz
22
+ from utils.geometry import compute_pointmap_torch
23
+ from eval_wrapper.eval_utils import npy2ply, filter_all_masks
24
+ from huggingface_hub import hf_hub_download
25
+
26
+ class EvalWrapper(torch.nn.Module):
27
+ def __init__(self,checkpoint_path,distributed=False,device="cuda",dtype=torch.float32,**kwargs):
28
+ super().__init__()
29
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
30
+ model_string = checkpoint['args'].model
31
+
32
+ self.model = eval(model_string).to(device)
33
+ if distributed:
34
+ rank, world_size, local_rank = misc.setup_distributed()
35
+ self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[local_rank],find_unused_parameters=True)
36
+
37
+ self.dtype = dtype
38
+ self.model.load_state_dict(checkpoint['model'])
39
+ self.model.eval()
40
+
41
+ def forward(self,x,dino_model=None):
42
+ pred, gt, loss, scale = eval_model(self.model,x,mode='viz',dino_model=dino_model,return_scale=True)
43
+ return pred, gt, loss, scale
44
+
45
+ class PostProcessWrapper(torch.nn.Module):
46
+ def __init__(self,pred_mask_threshold = 0.5, mode='novel_views',
47
+ debug=False,conf_dist_mode='isotonic',set_conf=None,percentile=20,
48
+ no_input_mask=False,no_pred_mask=False):
49
+ super().__init__()
50
+ self.pred_mask_threshold = pred_mask_threshold
51
+ self.mode = mode
52
+ self.debug = debug
53
+ self.conf_dist_mode = conf_dist_mode
54
+ self.set_conf = set_conf
55
+ self.percentile = percentile
56
+ self.no_input_mask = no_input_mask
57
+ self.no_pred_mask = no_pred_mask
58
+
59
+ def transform_pointmap(self,pointmap_cam,c2w):
60
+ # pointmap: shape H x W x 3
61
+ # c2w: shape 4 x 4
62
+ # we want to transform the pointmap to the world frame
63
+ pointmap_cam_h = torch.cat([pointmap_cam,torch.ones(pointmap_cam.shape[:-1]+(1,)).to(pointmap_cam.device)],dim=-1)
64
+ pointmap_world_h = pointmap_cam_h @ c2w.T
65
+ pointmap_world = pointmap_world_h[...,:3]/pointmap_world_h[...,3:4]
66
+ return pointmap_world
67
+
68
+ def reject_conf_points(self,conf_pts):
69
+ if self.set_conf is None:
70
+ raise ValueError("set_conf must be set")
71
+
72
+ conf_mask = conf_pts > self.set_conf
73
+ return conf_mask
74
+
75
+
76
+ def project_input_mask(self,pred_dict,batch):
77
+ input_mask = batch['input_cams']['original_valid_masks'][0][0] # shape H x W
78
+ input_c2w = batch['input_cams']['c2ws'][0][0]
79
+ input_w2c = torch.linalg.inv(input_c2w)
80
+ input_K = batch['input_cams']['Ks'][0][0]
81
+ H, W = input_mask.shape
82
+ pointmaps_input_cam = torch.stack([self.transform_pointmap(pmap,input_w2c@c2w) for pmap,c2w in zip(pred_dict['pointmaps'][0],batch['new_cams']['c2ws'][0])]) # bp: Assuming batch size is 1!!
83
+ img_coords = pointmaps_input_cam @ input_K.T
84
+ img_coords = (img_coords[...,:2]/img_coords[...,2:3]).int()
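+ # standard pinhole projection: multiply the camera-frame points by K, then divide by depth to get integer pixel coordinates (u, v)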
85
+
86
+ n_views, H, W = img_coords.shape[:3]
87
+ device = input_mask.device
88
+ if self.no_input_mask:
89
+ combined_mask = torch.ones((n_views, H, W), device=device)
90
+ else:
91
+ combined_mask = torch.zeros((n_views, H, W), device=device)
92
+
93
+ # Flatten spatial dims
94
+ xs = img_coords[..., 0].view(n_views, -1) # [V, H*W]
95
+ ys = img_coords[..., 1].view(n_views, -1) # [V, H*W]
96
+
97
+ # Create base pixel coords (i, j)
98
+ i_coords = torch.arange(H, device=device).view(-1, 1).expand(H, W).reshape(-1) # [H*W]
99
+ j_coords = torch.arange(W, device=device).view(1, -1).expand(H, W).reshape(-1) # [H*W]
100
+ mask_coords = torch.stack((i_coords, j_coords), dim=-1) # [H*W, 2], shared across views
101
+
102
+ # Mask for valid projections
103
+ valid = (xs >= 0) & (xs < W) & (ys >= 0) & (ys < H) # [V, H*W]
104
+
105
+ # Clip out-of-bounds coords for indexing (only valid will be used anyway)
106
+ xs_clipped = torch.clamp(xs, 0, W-1)
107
+ ys_clipped = torch.clamp(ys, 0, H-1)
108
+
109
+ # input_mask lookup per view
110
+ flat_input_mask = input_mask[ys_clipped, xs_clipped] # [V, H*W]
111
+ input_mask_mask = flat_input_mask & valid # apply valid range mask
112
+
113
+ # Apply mask to coords and depths
114
+ depth_points = pointmaps_input_cam[..., -1].view(n_views, -1) # [V, H*W]
115
+ input_depths = batch['input_cams']['depths'][0][0][ys_clipped, xs_clipped] # [V, H*W]
116
+
117
+ depth_mask = (depth_points > input_depths) & input_mask_mask # final mask [V, H*W]
118
+ #depth_mask = input_mask_mask # final mask [V, H*W]
119
+
120
+ # Get final (i,j) coords to write
121
+ final_i = mask_coords[:, 0].unsqueeze(0).expand(n_views, -1)[depth_mask] # [N_mask]
122
+ final_j = mask_coords[:, 1].unsqueeze(0).expand(n_views, -1)[depth_mask] # [N_mask]
123
+ final_view_idx = torch.arange(n_views, device=device).view(-1, 1).expand(-1, H*W)[depth_mask] # [N_mask]
124
+
125
+ # Scatter final mask
126
+ combined_mask[final_view_idx, final_i, final_j] = 1
127
+ return combined_mask.unsqueeze(0).bool()
128
+
129
+ def forward(self,pred_dict,batch):
130
+ if self.mode == 'novel_views':
131
+ project_masks = self.project_input_mask(pred_dict,batch)
132
+ pred_mask_raw = torch.sigmoid(pred_dict['classifier'])
133
+ if self.no_pred_mask:
134
+ pred_masks = torch.ones_like(project_masks).bool()
135
+ else:
136
+ pred_masks = (pred_mask_raw > self.pred_mask_threshold).bool()
137
+
138
+ conf_masks = self.reject_conf_points(pred_dict['conf_pointmaps'])
139
+ combined_mask = project_masks & pred_masks & conf_masks
140
+ batch['new_cams']['valid_masks'] = combined_mask
141
+
142
+ elif self.mode == 'input_view':
143
+ conf_masks = self.reject_conf_points(pred_dict['conf_pointmaps'])
144
+ if self.no_pred_mask:
145
+ pred_masks = torch.ones_like(conf_masks).bool()
146
+ else:
147
+ pred_mask_raw = torch.sigmoid(pred_dict['classifier'])
148
+ pred_masks = (pred_mask_raw > self.pred_mask_threshold).bool()
149
+ combined_mask = conf_masks & batch['new_cams']['valid_masks'] & pred_masks
150
+ batch['new_cams']['valid_masks'] = combined_mask # this is for visualization
151
+
152
+ return pred_dict, batch
153
+
154
+ class GenericLoaderSmall(torch.utils.data.Dataset):
155
+ def __init__(self,data_dir,mode="single_scene",dtype=torch.float32,n_pred_views=3,pred_input_only=False,min_depth=0.1,
156
+ pointmap_for_bb=None,run_octmae=False,false_positive=None,false_negative=None):
157
+ self.data_dir = data_dir
158
+ self.mode = mode
159
+ self.dtype = dtype
160
+ self.rng = np.random.RandomState(seed=42)
161
+ self.n_pred_views = n_pred_views
162
+ self.min_depth = self.depth_metric_to_uint16(min_depth)
163
+ if self.mode == "single_scene":
164
+ self.inputs = [data_dir]
165
+ self.pred_input_only = pred_input_only
166
+ if self.pred_input_only:
167
+ self.n_pred_views = 1
168
+ self.desired_resolution = (480,640)
169
+ self.resize_transform_rgb = transforms.Resize(self.desired_resolution)
170
+ self.resize_transform_depth = transforms.Resize(self.desired_resolution,interpolation=transforms.InterpolationMode.NEAREST)
171
+ self.pointmap_for_bb = pointmap_for_bb
172
+ self.run_octmae = run_octmae
173
+ self.false_positive = false_positive
174
+ self.false_negative = false_negative
175
+
176
+ def transform_pointmap(self,pointmap_cam,c2w):
177
+ # pointmap: shape H x W x 3
178
+ # c2w: shape 4 x 4
179
+ # we want to transform the pointmap to the world frame
180
+ pointmap_cam_h = torch.cat([pointmap_cam,torch.ones(pointmap_cam.shape[:-1]+(1,)).to(pointmap_cam.device)],dim=-1)
181
+ pointmap_world_h = pointmap_cam_h @ c2w.T
182
+ pointmap_world = pointmap_world_h[...,:3]/pointmap_world_h[...,3:4]
183
+ return pointmap_world
184
+
185
+ def __len__(self):
186
+ return len(self.inputs)
187
+
188
+ def look_at(self,cam_pos, center=(0,0,0), up=(0,0,1)):
189
+ z = center - cam_pos
190
+ z /= np.linalg.norm(z, axis=-1, keepdims=True)
191
+ y = -np.float32(up)
192
+ y = y - np.sum(y * z, axis=-1, keepdims=True) * z
193
+ y /= np.linalg.norm(y, axis=-1, keepdims=True)
194
+ x = np.cross(y, z, axis=-1)
195
+
196
+ cam2w = np.r_[np.c_[x,y,z,cam_pos],[[0,0,0,1]]]
197
+ return cam2w.astype(np.float32)
198
+
199
+ def find_new_views(self,n_views,geometric_median = (0,0,0),r_min=0.4,r_max=0.9):
200
+ rad = self.rng.uniform(r_min,r_max, size=n_views)
201
+ azi = self.rng.uniform(0, 2*np.pi, size=n_views)
202
+ ele = self.rng.uniform(-np.pi, np.pi, size=n_views)
203
+ cam_centers = np.c_[np.cos(azi), np.sin(azi)]
204
+ cam_centers = rad[:,None] * np.c_[np.cos(ele)[:,None]*cam_centers, np.sin(ele)] + geometric_median
205
+
206
+ c2ws = [self.look_at(cam_pos=cam_center,center=geometric_median) for cam_center in cam_centers]
207
+ return c2ws
208
+
209
+ def depth_uint16_to_metric(self,depth):
210
+ return depth / torch.iinfo(torch.uint16).max * 10.0 # threshold is in m, convert to uint16 value
211
+
212
+ def depth_metric_to_uint16(self,depth):
213
+ return depth * torch.iinfo(torch.uint16).max / 10.0 # threshold is in m, convert to uint16 value
214
+
215
+ def resize(self,depth,img,mask,K):
216
+ s_x = self.desired_resolution[1] / img.shape[1]
217
+ s_y = self.desired_resolution[0] / img.shape[0]
218
+ depth = self.resize_transform_depth(depth.unsqueeze(0)).squeeze(0)
219
+ img = self.resize_transform_rgb(img.permute(-1,0,1)).permute(1,2,0)
220
+ mask = self.resize_transform_depth(mask.unsqueeze(0)).squeeze(0)
221
+ K[0] *= s_x
222
+ K[1] *= s_y
223
+ return depth, img, mask, K
224
+
225
+ def add_false_positives_and_negatives(self,valid_mask,false_positive,false_negative):
226
+ # add false positives to the valid mask
227
+ # add false negatives to the valid mask
228
+ # return the new valid mask
229
+ n_total_pixels = valid_mask.sum()
230
+ n_pixels_left = n_total_pixels * (1-false_positive)
231
+
232
+ mask_pixels_coords = torch.where(valid_mask)
233
+ left_pixels_coords = torch.where(~valid_mask)
234
+
235
+ # false positives
236
+ n_false_positives = min(int(n_pixels_left * false_positive),n_pixels_left)
237
+ # randomly sample n_false_positives from mask_pixels_coords
238
+ false_positives = torch.randperm(len(left_pixels_coords[0]))[:n_false_positives]
239
+ valid_mask[left_pixels_coords[0][false_positives],left_pixels_coords[1][false_positives]] = 1
240
+
241
+ # false negatives
242
+ n_false_negatives = min(int(n_total_pixels * false_negative),n_total_pixels)
243
+ # randomly sample n_false_negatives from left_pixels_coords
244
+ false_negatives = torch.randperm(len(mask_pixels_coords[0]))[:n_false_negatives]
245
+ valid_mask[mask_pixels_coords[0][false_negatives],mask_pixels_coords[1][false_negatives]] = 0
246
+
247
+ return valid_mask
248
+
249
+ def __getitem__(self,idx):
250
+ scene_dir = self.inputs[idx]
251
+
252
+ data = dict(new_cams={},input_cams={})
253
+
254
+ c2w_path = os.path.join(scene_dir,'cam2world.pt')
255
+ if os.path.exists(c2w_path):
256
+ data['input_cams']['c2ws_original'] = [torch.load(c2w_path,map_location='cpu',weights_only=True).to(self.dtype)]
257
+ else:
258
+ data['input_cams']['c2ws_original'] = [torch.eye(4).to(self.dtype)]
259
+
260
+ data['input_cams']['c2ws'] = [torch.eye(4).to(self.dtype)]
261
+ data['input_cams']['Ks'] = [torch.load(os.path.join(scene_dir,'intrinsics.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
262
+ data['input_cams']['depths'] = [torch.from_numpy(np.array(Image.open(os.path.join(scene_dir,'depth.png'))).astype(np.float32))]
263
+ data['input_cams']['valid_masks'] = [torch.from_numpy(np.array(Image.open(os.path.join(scene_dir,'mask.png')))).bool()]
264
+ data['input_cams']['imgs'] = [torch.from_numpy(np.array(Image.open(os.path.join(scene_dir,'rgb.png'))))]
265
+
266
+ if self.false_positive is not None or self.false_negative is not None:
267
+ data['input_cams']['valid_masks'][0] = self.add_false_positives_and_negatives(data['input_cams']['valid_masks'][0],self.false_positive,self.false_negative)
268
+
269
+ if data['input_cams']['depths'][0].shape != self.desired_resolution:
270
+ data['input_cams']['depths'][0], data['input_cams']['imgs'][0], data['input_cams']['valid_masks'][0], data['input_cams']['Ks'][0] = \
271
+ self.resize(data['input_cams']['depths'][0], data['input_cams']['imgs'][0], data['input_cams']['valid_masks'][0], data['input_cams']['Ks'][0])
272
+
273
+ data['input_cams']['original_valid_masks'] = [data['input_cams']['valid_masks'][0].clone()]
274
+ data['input_cams']['valid_masks'][0] = data['input_cams']['valid_masks'][0] & \
275
+ (data['input_cams']['depths'][0] > self.min_depth)
276
+
277
+ if self.pred_input_only:
278
+ c2ws = [data['input_cams']['c2ws'][0].cpu().numpy()]
279
+ else:
280
+ input_mask = data['input_cams']['valid_masks'][0]
281
+ if self.pointmap_for_bb is not None:
282
+ pointmap_input = self.pointmap_for_bb
283
+ else:
284
+ pointmap_input = compute_pointmap_torch(self.depth_uint16_to_metric(data['input_cams']['depths'][0]),data['input_cams']['c2ws'][0],data['input_cams']['Ks'][0],device='cpu')[input_mask]
285
+ c2ws = pointmap_to_poses(pointmap_input, self.n_pred_views, inner_radius=1.1, outer_radius=2.5, device='cpu',run_octmae=self.run_octmae)
286
+ self.n_pred_views = len(c2ws)
287
+
288
+ data['new_cams'] = {}
289
+ data['new_cams']['c2ws'] = [torch.from_numpy(c2w).to(self.dtype) for c2w in c2ws]
290
+ data['new_cams']['depths'] = [torch.zeros_like(data['input_cams']['depths'][0]) for _ in range(self.n_pred_views)]
291
+ data['new_cams']['Ks'] = [data['input_cams']['Ks'][0] for _ in range(self.n_pred_views)]
292
+ if self.pred_input_only:
293
+ data['new_cams']['valid_masks'] = data['input_cams']['original_valid_masks']
294
+ else:
295
+ data['new_cams']['valid_masks'] = [torch.ones_like(data['input_cams']['valid_masks'][0]) for _ in range(self.n_pred_views)]
296
+
297
+ return data
298
+
299
+ def dict_to_float(d):
300
+ return {k: v.float() for k, v in d.items()}
301
+
302
+ def merge_dicts(d1,d2):
303
+ # stack the tensors along dimension 1
304
+ for k,v in d1.items():
305
+ d1[k] = torch.cat([d1[k],d2[k]],dim=1)
306
+ return d1
307
+
308
+ def compute_all_points(pred_dict,batch):
309
+ n_views = pred_dict['depths'].shape[1]
310
+ all_points = None
311
+ for i in range(n_views):
312
+ mask = batch['new_cams']['valid_masks'][0,i]
313
+ pointmap = compute_pointmap_torch(pred_dict['depths'][0,i],batch['new_cams']['c2ws'][0,i],batch['new_cams']['Ks'][0,i])
314
+ masked_points = pointmap[mask]
315
+ if all_points is None:
316
+ all_points = masked_points
317
+ else:
318
+ all_points = torch.cat([all_points,masked_points],dim=0)
319
+ return all_points
320
+
321
+ def eval_scene(model, data_dir,visualize=False,rr_addr=None,run_octmae=False,set_conf=5,
322
+ no_input_mask=False,no_pred_mask=False,no_filter_input_view=False,false_positive=None,false_negative=None,n_pred_views=5,
323
+ do_filter_all_masks=False, dino_model=None,tsdf=False):
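+ # rough flow: (1) predict and filter the input view, (2) sample novel viewpoints around the filtered input points,
+ # (3) predict depth/masks for those views, then merge everything and return one fused point cloud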
324
+
325
+ if dino_model is None:
326
+ # Loading DINOv2 model
327
+ dino_model = torch.hub.load('facebookresearch/dinov2', "dinov2_vitl14_reg")
328
+ dino_model.eval()
329
+ dino_model.to("cuda")
330
+
331
+ dataloader_input_view = GenericLoaderSmall(data_dir,n_pred_views=1,pred_input_only=True,false_positive=false_positive,false_negative=false_negative)
332
+ input_view_loader = DataLoader(dataloader_input_view, batch_size=1, shuffle=True, collate_fn=collate)
333
+ input_view_batch = next(iter(input_view_loader))
334
+
335
+ postprocessor_input_view = PostProcessWrapper(mode='input_view',set_conf=set_conf,
336
+ no_input_mask=no_input_mask,no_pred_mask=no_pred_mask)
337
+ postprocessor_pred_views = PostProcessWrapper(mode='novel_views',debug=False,set_conf=set_conf,
338
+ no_input_mask=no_input_mask,no_pred_mask=no_pred_mask)
339
+ fused_meshes = None
340
+ with torch.no_grad():
341
+ pred_input_view, gt_input_view, _, scale_factor = model(input_view_batch,dino_model)
342
+ if no_filter_input_view:
343
+ pred_input_view['pointmaps'] = input_view_batch['input_cams']['pointmaps']
344
+ pred_input_view['depths'] = input_view_batch['input_cams']['depths']
345
+ else:
346
+ pred_input_view, input_view_batch = postprocessor_input_view(pred_input_view,input_view_batch)
347
+
348
+ input_points = pred_input_view['pointmaps'][0][0][input_view_batch['new_cams']['valid_masks'][0][0]] * (1.0/scale_factor)
349
+ if input_points.shape[0] == 0:
350
+ input_points = None
351
+
352
+ dataloader_pred_views = GenericLoaderSmall(data_dir,n_pred_views=n_pred_views,pred_input_only=False,
353
+ pointmap_for_bb=input_points,run_octmae=run_octmae)
354
+ pred_views_loader = DataLoader(dataloader_pred_views, batch_size=1, shuffle=True, collate_fn=collate)
355
+ pred_views_batch = next(iter(pred_views_loader))
356
+
357
+ # this is for the mask ablation
358
+ if (false_positive is not None or false_negative is not None) and input_points is not None:
359
+ pred_views_batch['input_cams']['valid_masks'] = input_view_batch['input_cams']['valid_masks']
360
+
361
+ pred_new_views, gt_new_views, _, scale_factor = model(pred_views_batch,dino_model)
362
+ pred_new_views, pred_views_batch = postprocessor_pred_views(pred_new_views,pred_views_batch)
363
+
364
+ pred = merge_dicts(dict_to_float(pred_input_view),dict_to_float(pred_new_views))
365
+ gt = merge_dicts(dict_to_float(gt_input_view),dict_to_float(gt_new_views))
366
+
367
+ batch = copy.deepcopy(input_view_batch)
368
+ batch['new_cams'] = merge_dicts(input_view_batch['new_cams'],pred_views_batch['new_cams'])
369
+ gt['pointmaps'] = None # make sure it's not used in viz
370
+
371
+ if do_filter_all_masks:
372
+ batch = filter_all_masks(pred,input_view_batch,max_outlier_views=1)
373
+
374
+ # scale factor is the scale we applied to the input view for inference
375
+ all_points = compute_all_points(pred,batch)
376
+ all_points = all_points*(1.0/scale_factor)
377
+
378
+ # transform all_points to the original coordinate system
379
+ all_points_h = torch.cat([all_points,torch.ones(all_points.shape[:-1]+(1,)).to(all_points.device)],dim=-1)
380
+ all_points_original = all_points_h @ batch['input_cams']['c2ws_original'][0][0].T
381
+ all_points = all_points_original[...,:3]
382
+
383
+ # optionally fuse the predictions into a simple TSDF mesh for visualization
384
+ if tsdf:
385
+ fused_meshes = fuse_batch(pred,gt,batch,voxel_size=0.002)
386
+ else:
387
+ fused_meshes = None
388
+
389
+ if visualize:
390
+ just_load_viz(pred, gt, batch, addr=rr_addr,fused_meshes=fused_meshes)
391
+ return all_points
392
+
393
+
394
+ def main():
395
+ parser = argparse.ArgumentParser()
396
+ parser.add_argument("data_dir", type=str)
397
+ parser.add_argument("--rr_addr", type=str, default="0.0.0.0:"+os.getenv("RERUN_RECORDING","9876"))
398
+ parser.add_argument("--visualize", action="store_true", default=False)
399
+ parser.add_argument("--run_octmae", action="store_true", default=False)
400
+ parser.add_argument("--set_conf", type=float, default=5)
401
+ parser.add_argument("--n_pred_views", type=int, default=5)
402
+ parser.add_argument("--filter_all_masks", action="store_true", default=False)
403
+ parser.add_argument("--tsdf", action="store_true", default=False)
404
+ # ablation settings
405
+ parser.add_argument("--no_input_mask", action="store_true", default=False)
406
+ parser.add_argument("--no_pred_mask", action="store_true", default=False)
407
+ parser.add_argument("--no_filter_input_view", action="store_true", default=False)
408
+ parser.add_argument("--false_positive", type=float, default=None)
409
+ parser.add_argument("--false_negative", type=float, default=None)
410
+ args = parser.parse_args()
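+ # example invocation (the path is a placeholder): python eval_wrapper/eval.py /path/to/scene --visualize --n_pred_views 5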
411
+
412
+ print("Loading checkpoint from Huggingface")
413
+ rayst3r_checkpoint = hf_hub_download("bartduis/rayst3r", "rayst3r.pth")
414
+
415
+ model = EvalWrapper(rayst3r_checkpoint,distributed=False)
416
+ all_points = eval_scene(model, args.data_dir,visualize=args.visualize,rr_addr=args.rr_addr,run_octmae=args.run_octmae,set_conf=args.set_conf,
417
+ no_input_mask=args.no_input_mask,no_pred_mask=args.no_pred_mask,no_filter_input_view=args.no_filter_input_view,false_positive=args.false_positive,
418
+ false_negative=args.false_negative,n_pred_views=args.n_pred_views,
419
+ do_filter_all_masks=args.filter_all_masks,tsdf=args.tsdf).cpu().numpy()
420
+ all_points_save = os.path.join(args.data_dir,"inference_points.ply")
421
+ o3d_pc = npy2ply(all_points,colors=None,normals=None)
422
+ o3d.io.write_point_cloud(all_points_save, o3d_pc)
423
+
424
+ if __name__ == "__main__":
425
+ main()
eval_wrapper/eval_utils.py ADDED
@@ -0,0 +1,125 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from scipy.stats import norm, lognorm
4
+ import torch
5
+ import open3d as o3d
6
+
7
+ def colorize_points_with_turbo_all_dims(points, method='norm',cmap='turbo'):
8
+ """
9
+ Assigns colors to 3D points using the 'turbo' colormap based on a scalar computed from all 3 dimensions.
10
+
11
+ Args:
12
+ points (np.ndarray): (N, 3) array of 3D points.
13
+ method (str): Method for reducing 3D point to scalar. Options: 'norm', 'pca'.
14
+
15
+ Returns:
16
+ np.ndarray: (N, 3) RGB colors in [0, 1].
17
+ """
18
+ assert points.shape[1] == 3, "Input must be of shape (N, 3)"
19
+
20
+ if method == 'norm':
21
+ scalar = np.linalg.norm(points, axis=1)
22
+ elif method == 'pca':
23
+ # Project onto first principal component
24
+ mean = points.mean(axis=0)
25
+ centered = points - mean
26
+ u, s, vh = np.linalg.svd(centered, full_matrices=False)
27
+ scalar = centered @ vh[0] # Project onto first principal axis
28
+ else:
29
+ raise ValueError(f"Unknown method '{method}'")
30
+
31
+ # Normalize scalar to [0, 1]
32
+ scalar_min, scalar_max = scalar.min(), scalar.max()
33
+ normalized = (scalar - scalar_min) / (scalar_max - scalar_min + 1e-8)
34
+
35
+ # Apply turbo colormap
36
+ cmap = plt.colormaps.get_cmap(cmap)
37
+ colors = cmap(normalized)[:, :3] # Drop alpha
38
+
39
+ return colors
40
+
41
+ def npy2ply(points,colors=None,normals=None):
42
+ cloud = o3d.geometry.PointCloud()
43
+ cloud.points = o3d.utility.Vector3dVector(points.astype(np.float64))
44
+
45
+ # compute the normals
46
+ if colors is not None:
47
+ if colors.max()>1:
48
+ colors = colors/255.0
49
+ cloud.colors = o3d.utility.Vector3dVector(colors.astype(np.float64))
50
+ else:
51
+ colors = colorize_points_with_turbo_all_dims(points)
52
+ cloud.colors = o3d.utility.Vector3dVector(colors.astype(np.float64))
53
+ if normals is not None:
54
+ cloud.normals = o3d.utility.Vector3dVector(normals.astype(np.float64))
55
+ return cloud
56
+
57
+ def transform_pointmap(pointmap_cam,c2w):
58
+ # pointmap: shape H x W x 3
59
+ # c2w: shape 4 x 4
60
+ # we want to transform the pointmap to the world frame
61
+ pointmap_cam_h = torch.cat([pointmap_cam,torch.ones(pointmap_cam.shape[:-1]+(1,)).to(pointmap_cam.device)],dim=-1)
62
+ pointmap_world_h = pointmap_cam_h @ c2w.T
63
+ pointmap_world = pointmap_world_h[...,:3]/pointmap_world_h[...,3:4]
64
+ return pointmap_world
65
+
66
+ def filter_all_masks(pred_dict, batch, max_outlier_views=1):
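+ # cross-view consistency filter: a pixel's 3D point is dropped if more than max_outlier_views other views
+ # project it, unoccluded, into their predicted background region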
67
+ pred_masks = (torch.sigmoid(pred_dict['classifier'][0]).float() < 0.5).bool() # [V, H, W]
68
+ n_views, H, W = pred_masks.shape
69
+ device = pred_masks.device
70
+
71
+ K = batch['input_cams']['Ks'][0][0] # [3, 3]
72
+ c2ws = batch['new_cams']['c2ws'][0] # [V, 4, 4]
73
+ w2cs = torch.linalg.inv(c2ws) # [V, 4, 4]
74
+
75
+ pointmaps = pred_dict['pointmaps'][0] # [V, H, W, 3]
76
+ pointmaps_h = torch.cat([pointmaps, torch.ones_like(pointmaps[..., :1])], dim=-1) # [V, H, W, 4]
77
+
78
+ visibility_count = torch.zeros((n_views, H, W), dtype=torch.int32, device=device)
79
+
80
+ for j in range(n_views):
81
+ # Project pointmap j to all other views i ≠ j
82
+ pmap_h = pointmaps_h[j] # [H, W, 4], world-space points from view j
83
+ pmap_h = pmap_h.view(1, H, W, 4).expand(n_views, -1, -1, -1) # [V, H, W, 4]
84
+
85
+ # Compute T_{i←j} = w2cs[i] @ c2ws[j]
86
+ T = w2cs @ c2ws[j] # [V, 4, 4]
87
+ T = T.view(n_views, 1, 1, 4, 4) # [V, 1, 1, 4, 4]
88
+
89
+ # Transform to i-th camera frame
90
+ pts_cam = torch.matmul(T, pmap_h.unsqueeze(-1)).squeeze(-1)[..., :3] # [V, H, W, 3]
91
+
92
+ # Project to image
93
+ img_coords = torch.matmul(pts_cam, K.T) # [V, H, W, 3]
94
+ img_coords = img_coords[..., :2] / img_coords[..., 2:3].clamp(min=1e-6)
95
+ img_coords = img_coords.round().long() # [V, H, W, 2]
96
+
97
+ x = img_coords[..., 0].clamp(0, W - 1)
98
+ y = img_coords[..., 1].clamp(0, H - 1)
99
+ valid = (img_coords[..., 0] >= 0) & (img_coords[..., 0] < W) & \
100
+ (img_coords[..., 1] >= 0) & (img_coords[..., 1] < H)
101
+
102
+ # Get depth of the reprojected point from j into i
103
+ reprojected_depth = pts_cam[..., 2] # [V, H, W]
104
+
105
+ # Get depth of each view's original pointmap
106
+ target_depth = pointmaps[:, :, :, 2] # [V, H, W]
107
+
108
+ # Lookup the depth value in view i at the projected location (x, y)
109
+ depth_at_pixel = target_depth[torch.arange(n_views).view(-1, 1, 1), y, x] # [V, H, W]
110
+
111
+ # Check that the point is in front (closest along ray)
112
+ is_closest = reprojected_depth < depth_at_pixel # [V, H, W]
113
+
114
+ # Lookup mask values at projected location
115
+ projected_mask = pred_masks[torch.arange(n_views).view(-1, 1, 1), y, x] & valid # [V, H, W]
116
+
117
+ # Only consider as visible if it’s within mask and closest point
118
+ visible = projected_mask & is_closest # [V, H, W]
119
+
120
+ # Count how many views see each pixel from j
121
+ visibility_count[j] = visible.sum(dim=0)
122
+
123
+ visibility_mask = (visibility_count <= max_outlier_views).bool()
124
+ batch['new_cams']['valid_masks'] = visibility_mask & batch['new_cams']['valid_masks']
125
+ return batch
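
A minimal usage sketch for the helpers above; the random pointmap and identity pose are placeholders:

import torch
import open3d as o3d
from eval_wrapper.eval_utils import npy2ply, transform_pointmap

# placeholder H x W x 3 pointmap in the camera frame and an identity camera-to-world pose
pointmap_cam = torch.rand(48, 64, 3)
c2w = torch.eye(4)

# move the points into the world frame, then export them as a colored point cloud
pointmap_world = transform_pointmap(pointmap_cam, c2w)
cloud = npy2ply(pointmap_world.reshape(-1, 3).numpy())  # colors fall back to the turbo colormap
o3d.io.write_point_cloud("inference_points_demo.ply", cloud)
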
eval_wrapper/sample_poses.py ADDED
@@ -0,0 +1,100 @@
1
+ import numpy as np
2
+ import torch
3
+ import open3d as o3d
4
+
5
+
6
+ def look_at(cam_pos, target=(0,0,0), up=(0,0,1)):
7
+ target = np.asarray(target, dtype=float)  # the default target is a plain tuple
+ # Forward vector
8
+ forward = target - cam_pos
9
+ forward /= np.linalg.norm(forward)
10
+
11
+ # Right vector (fall back to a different up vector if up is parallel to forward)
12
+ right = np.cross(up, forward)
13
+ if np.linalg.norm(right) < 1e-6:
14
+ up = np.array([1, 0, 0])
15
+ right = np.cross(up, forward)
16
+
17
+ right /= np.linalg.norm(right)
18
+ up = np.cross(forward, right)
19
+
20
+ # Build rotation and translation matrices
21
+ rotation = np.eye(4)
22
+ rotation[:3, :3] = np.vstack([right, up, -forward]).T
23
+
24
+
25
+ translation = np.eye(4)
26
+ translation[:3, 3] = cam_pos
27
+
28
+ cam_to_world = translation @ rotation
29
+ cam_to_world[:3,2] = -cam_to_world[:3,2]
30
+ cam_to_world[:3,1] = -cam_to_world[:3,1]
31
+ # the sign flips above convert the y/z columns to an OpenCV-style camera frame (x right, y down, z forward)
32
+ return cam_to_world
33
+
34
+
35
+ def sample_camera_poses(target: np.ndarray, inner_radius: float, outer_radius: float, n: int,seed: int = 42,mode: str = 'grid') -> np.ndarray:
36
+ """
37
+ Samples `n` camera poses uniformly on a sphere of given `radius` around `target`.
38
+ The cameras are positioned randomly and oriented to look at `target`.
39
+
40
+ Args:
41
+ target (np.ndarray): 3D point (x, y, z) that cameras should look at.
42
+ inner_radius (float): Radius of the sphere.
43
+ outer_radius (float): Radius of the sphere.
44
+ n (int): Number of camera poses to sample.
45
+
46
+ Returns:
47
+ torch.Tensor: (n, 4, 4) array of transformation matrices (camera-to-world).
48
+ """
49
+ cameras = []
50
+ np.random.seed(seed)
51
+
52
+ u_1 = np.linspace(0,1,n,endpoint=False)
53
+ u_2 = np.linspace(0,0.7,n)
54
+ u_1, u_2 = np.meshgrid(u_1, u_2)
55
+ u_1 = u_1.flatten()
56
+ u_2 = u_2.flatten()
57
+ theta = np.arccos(1-2*u_2)
58
+ phi = 2*np.pi*u_1
59
+ n_poses = len(phi)
60
+
61
+ radii = np.random.uniform(inner_radius, outer_radius, n_poses)
62
+ cameras = []
63
+
64
+ r_z = np.array([[0,-1,0],[1,0,0],[0,0,1]])
65
+
66
+ for i in range(n_poses):
67
+ # Camera position on the sphere
68
+ x = target[0] + radii[i] * np.sin(theta[i]) * np.cos(phi[i])
69
+ y = target[1] + radii[i] * np.sin(theta[i]) * np.sin(phi[i])
70
+ z = target[2] + radii[i] * np.cos(theta[i])
71
+ cam_pos = np.array([x, y, z])
72
+ cam2world = look_at(cam_pos, target)
73
+ if theta[i] == 0:
74
+ cam2world[:3,:3] = cam2world[:3,:3] @ r_z # rotate 90 degrees around z axis for the camera opposite to the input
75
+ cameras.append(cam2world)
76
+ cameras = np.unique(cameras, axis=0)
77
+ return np.stack(cameras)
78
+
79
+
80
+ def pointmap_to_poses(pointmaps: torch.Tensor, n_poses: int, inner_radius: float = 1.1, outer_radius: float = 2.5, device: str = 'cuda',
81
+ bb_mode: str='bb',run_octmae: bool = False) -> np.ndarray:
82
+ """
83
+ Samples camera poses on a sphere around the center of the pointmap's bounding box.
84
+ The sphere radius is derived from the camera-to-center distance and the bounding-box size, and every camera looks at the center.
85
+ """
86
+
87
+ bb_min_corner = pointmaps.min(dim=0)[0].cpu().numpy()
88
+ bb_max_corner = pointmaps.max(dim=0)[0].cpu().numpy()
89
+ center = (bb_min_corner + bb_max_corner) / 2  # center of the bounding box
+ # (alternative: inner_radius = inner_radius * np.linalg.norm(bb_max_corner - bb_min_corner) / 2, i.e. a scalar multiple of the bounding-box radius)
90
+ bb_radius = np.linalg.norm(bb_max_corner - bb_min_corner) / 2
91
+ cam2center_dist = np.linalg.norm(center)
92
+
93
+ if run_octmae:
94
+ radius = max(1.2*cam2center_dist,2.5*bb_radius)
95
+ else:
96
+ radius = max(0.7*cam2center_dist,1.3*bb_radius)
97
+ inner_radius = radius
98
+ outer_radius = radius
99
+ camera_poses = sample_camera_poses(center, inner_radius, outer_radius, n_poses)
100
+ return camera_poses
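
A short sketch of how the samplers above are typically driven; the random point blob stands in for a real pointmap and all radii are illustrative:

import numpy as np
import torch
from eval_wrapper.sample_poses import pointmap_to_poses, sample_camera_poses

# (N, 3) scene points; a random blob in front of the camera as a placeholder
points = torch.rand(1000, 3) + torch.tensor([0.0, 0.0, 1.0])

# camera-to-world matrices looking at the center of the points' bounding box
c2ws = pointmap_to_poses(points, n_poses=4)
print(c2ws.shape)  # (K, 4, 4) with K <= 16, duplicates removed

# the lower-level sampler can also be called directly with an explicit target and radii
poses = sample_camera_poses(np.array([0.0, 0.0, 1.0]), inner_radius=1.0, outer_radius=1.5, n=4)
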
example_scene/cam2world.pt ADDED
Binary file (1.25 kB). View file
 
example_scene/intrinsics.pt ADDED
Binary file (1.2 kB). View file
 
extensions/curope/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ from .curope2d import cuRoPE2D
extensions/curope/curope.cpp ADDED
@@ -0,0 +1,69 @@
1
+ /*
2
+ Copyright (C) 2022-present Naver Corporation. All rights reserved.
3
+ Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4
+ */
5
+
6
+ #include <torch/extension.h>
7
+
8
+ // forward declaration
9
+ void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd );
10
+
11
+ void rope_2d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd )
12
+ {
13
+ const int B = tokens.size(0);
14
+ const int N = tokens.size(1);
15
+ const int H = tokens.size(2);
16
+ const int D = tokens.size(3) / 4;
17
+
18
+ auto tok = tokens.accessor<float, 4>();
19
+ auto pos = positions.accessor<int64_t, 3>();
20
+
21
+ for (int b = 0; b < B; b++) {
22
+ for (int x = 0; x < 2; x++) { // y and then x (2d)
23
+ for (int n = 0; n < N; n++) {
24
+
25
+ // grab the token position
26
+ const int p = pos[b][n][x];
27
+
28
+ for (int h = 0; h < H; h++) {
29
+ for (int d = 0; d < D; d++) {
30
+ // grab the two values
31
+ float u = tok[b][n][h][d+0+x*2*D];
32
+ float v = tok[b][n][h][d+D+x*2*D];
33
+
34
+ // grab the cos,sin
35
+ const float inv_freq = fwd * p / powf(base, d/float(D));
36
+ float c = cosf(inv_freq);
37
+ float s = sinf(inv_freq);
38
+
39
+ // write the result
40
+ tok[b][n][h][d+0+x*2*D] = u*c - v*s;
41
+ tok[b][n][h][d+D+x*2*D] = v*c + u*s;
42
+ }
43
+ }
44
+ }
45
+ }
46
+ }
47
+ }
48
+
49
+ void rope_2d( torch::Tensor tokens, // B,N,H,D
50
+ const torch::Tensor positions, // B,N,2
51
+ const float base,
52
+ const float fwd )
53
+ {
54
+ TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions");
55
+ TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions");
56
+ TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions");
57
+ TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions");
58
+ TORCH_CHECK(positions.size(2) == 2, "positions.shape[2] must be equal to 2");
59
+ TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" );
60
+
61
+ if (tokens.is_cuda())
62
+ rope_2d_cuda( tokens, positions, base, fwd );
63
+ else
64
+ rope_2d_cpu( tokens, positions, base, fwd );
65
+ }
66
+
67
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
68
+ m.def("rope_2d", &rope_2d, "RoPE 2d forward/backward");
69
+ }
extensions/curope/curope.egg-info/PKG-INFO ADDED
@@ -0,0 +1,10 @@
1
+ Metadata-Version: 2.1
2
+ Name: curope
3
+ Version: 0.0.0
4
+ Summary: UNKNOWN
5
+ Home-page: UNKNOWN
6
+ License: UNKNOWN
7
+ Platform: UNKNOWN
8
+
9
+ UNKNOWN
10
+
extensions/curope/curope.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,9 @@
1
+ __init__.py
2
+ curope.cpp
3
+ curope2d.py
4
+ kernels.cu
5
+ setup.py
6
+ curope.egg-info/PKG-INFO
7
+ curope.egg-info/SOURCES.txt
8
+ curope.egg-info/dependency_links.txt
9
+ curope.egg-info/top_level.txt
extensions/curope/curope.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
1
+
extensions/curope/curope.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+ curope
extensions/curope/curope2d.py ADDED
@@ -0,0 +1,43 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import torch
5
+
6
+ try:
7
+ import curope as _kernels # run `python setup.py install`
8
+ except ModuleNotFoundError:
9
+ from . import curope as _kernels # run `python setup.py build_ext --inplace`
10
+
11
+ from torch.amp import custom_fwd, custom_bwd
12
+
13
+ class cuRoPE2D_func (torch.autograd.Function):
14
+
15
+ @staticmethod
16
+ @custom_fwd(device_type='cuda', cast_inputs=torch.float32)
17
+ def forward(ctx, tokens, positions, base, F0=1):
18
+ ctx.save_for_backward(positions)
19
+ ctx.saved_base = base
20
+ ctx.saved_F0 = F0
21
+ # tokens = tokens.clone() # uncomment this if inplace doesn't work
22
+ _kernels.rope_2d( tokens, positions, base, F0 )
23
+ ctx.mark_dirty(tokens)
24
+ return tokens
25
+
26
+ @staticmethod
27
+ @custom_bwd(device_type='cuda')
28
+ def backward(ctx, grad_res):
29
+ positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0
30
+ _kernels.rope_2d( grad_res, positions, base, -F0 )
31
+ ctx.mark_dirty(grad_res)
32
+ return grad_res, None, None, None
33
+
34
+
35
+ class cuRoPE2D(torch.nn.Module):
36
+ def __init__(self, freq=100.0, F0=1.0):
37
+ super().__init__()
38
+ self.base = freq
39
+ self.F0 = F0
40
+
41
+ def forward(self, tokens, positions):
42
+ cuRoPE2D_func.apply( tokens.transpose(1,2), positions, self.base, self.F0 )
43
+ return tokens
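
A hedged sketch of calling the module once the CUDA extension is built; it assumes a CUDA device, and note that the kernel expects the underlying memory layout (B, N, heads, head_dim), as produced at the call sites in models/blocks.py:

import torch
from extensions.curope import cuRoPE2D

rope = cuRoPE2D(freq=100.0)

B, heads, N, head_dim = 1, 8, 16, 64   # head_dim must be a multiple of 4
# allocate as (B, N, heads, head_dim) so the kernel's contiguity check passes, then view as (B, heads, N, head_dim)
q = torch.randn(B, N, heads, head_dim, device='cuda').transpose(1, 2)
# integer (y, x) patch positions for each of the N tokens on a 4 x 4 grid
pos = torch.cartesian_prod(torch.arange(4), torch.arange(4)).view(1, N, 2).cuda()

q = rope(q, pos)  # rotary embedding is applied in place on the CUDA tensor
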
extensions/curope/kernels.cu ADDED
@@ -0,0 +1,108 @@
1
+ /*
2
+ Copyright (C) 2022-present Naver Corporation. All rights reserved.
3
+ Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4
+ */
5
+
6
+ #include <torch/extension.h>
7
+ #include <cuda.h>
8
+ #include <cuda_runtime.h>
9
+ #include <vector>
10
+
11
+ #define CHECK_CUDA(tensor) {\
12
+ TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \
13
+ TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); }
14
+ void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));}
15
+
16
+
17
+ template < typename scalar_t >
18
+ __global__ void rope_2d_cuda_kernel(
19
+ //scalar_t* __restrict__ tokens,
20
+ torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> tokens,
21
+ const int64_t* __restrict__ pos,
22
+ const float base,
23
+ const float fwd )
24
+ // const int N, const int H, const int D )
25
+ {
26
+ // tokens shape = (B, N, H, D)
27
+ const int N = tokens.size(1);
28
+ const int H = tokens.size(2);
29
+ const int D = tokens.size(3);
30
+
31
+ // each block update a single token, for all heads
32
+ // each thread takes care of a single output
33
+ extern __shared__ float shared[];
34
+ float* shared_inv_freq = shared + D;
35
+
36
+ const int b = blockIdx.x / N;
37
+ const int n = blockIdx.x % N;
38
+
39
+ const int Q = D / 4;
40
+ // one token = [0..Q : Q..2Q : 2Q..3Q : 3Q..D]
41
+ // u_Y v_Y u_X v_X
42
+
43
+ // shared memory: first, compute inv_freq
44
+ if (threadIdx.x < Q)
45
+ shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q));
46
+ __syncthreads();
47
+
48
+ // start of X or Y part
49
+ const int X = threadIdx.x < D/2 ? 0 : 1;
50
+ const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X
51
+
52
+ // grab the cos,sin appropriate for me
53
+ const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q];
54
+ const float cos = cosf(freq);
55
+ const float sin = sinf(freq);
56
+ /*
57
+ float* shared_cos_sin = shared + D + D/4;
58
+ if ((threadIdx.x % (D/2)) < Q)
59
+ shared_cos_sin[m+0] = cosf(freq);
60
+ else
61
+ shared_cos_sin[m+Q] = sinf(freq);
62
+ __syncthreads();
63
+ const float cos = shared_cos_sin[m+0];
64
+ const float sin = shared_cos_sin[m+Q];
65
+ */
66
+
67
+ for (int h = 0; h < H; h++)
68
+ {
69
+ // then, load all the token for this head in shared memory
70
+ shared[threadIdx.x] = tokens[b][n][h][threadIdx.x];
71
+ __syncthreads();
72
+
73
+ const float u = shared[m];
74
+ const float v = shared[m+Q];
75
+
76
+ // write output
77
+ if ((threadIdx.x % (D/2)) < Q)
78
+ tokens[b][n][h][threadIdx.x] = u*cos - v*sin;
79
+ else
80
+ tokens[b][n][h][threadIdx.x] = v*cos + u*sin;
81
+ }
82
+ }
83
+
84
+ void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd )
85
+ {
86
+ const int B = tokens.size(0); // batch size
87
+ const int N = tokens.size(1); // sequence length
88
+ const int H = tokens.size(2); // number of heads
89
+ const int D = tokens.size(3); // dimension per head
90
+
91
+ TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous");
92
+ TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous");
93
+ TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape");
94
+ TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4");
95
+
96
+ // one block per (batch, token) pair, one thread per feature channel
97
+ const int THREADS_PER_BLOCK = D;
98
+ const int N_BLOCKS = B * N; // each block takes care of H*D values
99
+ const int SHARED_MEM = sizeof(float) * (D + D/4);
100
+
101
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), "rope_2d_cuda", ([&] {
102
+ rope_2d_cuda_kernel<scalar_t> <<<N_BLOCKS, THREADS_PER_BLOCK, SHARED_MEM>>> (
103
+ //tokens.data_ptr<scalar_t>(),
104
+ tokens.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
105
+ pos.data_ptr<int64_t>(),
106
+ base, fwd); //, N, H, D );
107
+ }));
108
+ }
extensions/curope/setup.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ from setuptools import setup
5
+ from torch import cuda
6
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7
+
8
+ # compile for all possible CUDA architectures
9
+ all_cuda_archs = cuda.get_gencode_flags().replace('compute=','arch=').split()
10
+ # alternatively, you can list cuda archs that you want, eg:
11
+ # all_cuda_archs = [
12
+ # '-gencode', 'arch=compute_70,code=sm_70',
13
+ # '-gencode', 'arch=compute_75,code=sm_75',
14
+ # '-gencode', 'arch=compute_80,code=sm_80',
15
+ # '-gencode', 'arch=compute_86,code=sm_86'
16
+ # ]
17
+
18
+ setup(
19
+ name = 'curope',
20
+ ext_modules = [
21
+ CUDAExtension(
22
+ name='curope',
23
+ sources=[
24
+ "curope.cpp",
25
+ "kernels.cu",
26
+ ],
27
+ extra_compile_args = dict(
28
+ nvcc=['-O3','--ptxas-options=-v',"--use_fast_math"]+all_cuda_archs,
29
+ cxx=['-O3'])
30
+ )
31
+ ],
32
+ cmdclass = {
33
+ 'build_ext': BuildExtension
34
+ })
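
As the comments in curope2d.py note, the extension can be installed system-wide or built in place; a quick sanity check after building (the build command assumes a CUDA toolchain matching the installed PyTorch):

# from inside extensions/curope/:
#   python setup.py build_ext --inplace    (or `python setup.py install` for a system-wide install)
import torch
from extensions.curope import cuRoPE2D   # imports the compiled kernels via curope2d.py

print(torch.cuda.is_available())          # the fast path in kernels.cu needs a CUDA device
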
input/cam2world.pt ADDED
Binary file (1.25 kB). View file
 
input/intrinsics.pt ADDED
Binary file (1.07 kB). View file
 
main.py ADDED
@@ -0,0 +1,198 @@
1
+ bb = breakpoint  # short alias for dropping into the debugger
2
+ import torch
3
+ from torch.utils.data import DataLoader
4
+ import wandb
5
+ from argparse import ArgumentParser
6
+ from datasets.octmae import OctMae
7
+ from datasets.foundation_pose import FoundationPose
8
+ from datasets.generic_loader import GenericLoader
9
+
10
+ from utils.collate import collate
11
+ from models.rayquery import RayQuery
12
+ from engine import train_epoch, eval_epoch, eval_model
13
+ import torch.nn as nn
14
+ from models.rayquery import RayQuery, PointmapEncoder, RayEncoder
15
+ from models.losses import *
16
+ import utils.misc as misc
17
+ import os
18
+ from utils.viz import just_load_viz
19
+ from utils.fusion import fuse_batch
20
+ import socket
21
+ import time
22
+ from utils.augmentations import *
23
+
24
+ def parse_args():
25
+ parser = ArgumentParser()
26
+ parser.add_argument("--dataset_train", type=str, default="TableOfCubes(size=10,n_views=2,seed=747)")
27
+ parser.add_argument("--dataset_test", type=str, default="TableOfCubes(size=10,n_views=2,seed=787)")
28
+ parser.add_argument("--dataset_just_load", type=str, default=None)
29
+ parser.add_argument("--logdir", type=str, default="logs")
30
+ parser.add_argument("--batch_size", type=int, default=5)
31
+ parser.add_argument("--n_epochs", type=int, default=100)
32
+ parser.add_argument("--n_workers", type=int, default=4)
33
+ parser.add_argument("--model", type=str, default="RayQuery(ray_enc=RayEncoder(),pointmap_enc=PointmapEncoder(),criterion=RayCompletion(ConfLoss(L21)))")
34
+ parser.add_argument("--save_every", type=int, default=1)
35
+ parser.add_argument("--resume", type=str, default=None)
36
+ parser.add_argument("--eval_every", type=int, default=3)
37
+ parser.add_argument("--wandb_project", type=str, default=None)
38
+ parser.add_argument("--wandb_run_name", type=str, default="init")
39
+ parser.add_argument("--just_load", action="store_true")
40
+ parser.add_argument("--device", type=str, default="cuda")
41
+ parser.add_argument("--rr_addr", type=str, default="0.0.0.0:"+os.getenv("RERUN_RECORDING","9876"))
42
+ parser.add_argument("--mesh", action="store_true")
43
+ parser.add_argument("--max_norm", type=float, default=-1)
44
+ parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)')
45
+ parser.add_argument('--blr', type=float, default=1.5e-4, metavar='LR',
46
+ help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
47
+ parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
48
+ help='lower lr bound for cyclic schedulers that hit 0')
49
+ parser.add_argument('--warmup_epochs', type=int, default=10)
50
+ parser.add_argument('--weight_decay', type=float, default=0.01)
51
+ parser.add_argument('--normalize_mode',type=str,default='None')
52
+ parser.add_argument('--start_from',type=str,default=None)
53
+ parser.add_argument('--augmentor',type=str,default='None')
54
+ return parser.parse_args()
55
+
56
+ def main(args):
57
+ load_dino = False
58
+ if not args.just_load:
59
+ dataset_train = eval(args.dataset_train)
60
+ dataset_test = eval(args.dataset_test)
61
+ if not dataset_train.prefetch_dino:
62
+ load_dino = True
63
+ rank, world_size, local_rank = misc.setup_distributed()
64
+ sampler_train = torch.utils.data.DistributedSampler(
65
+ dataset_train, num_replicas=world_size, rank=rank, shuffle=True
66
+ )
67
+
68
+ sampler_test = torch.utils.data.DistributedSampler(
69
+ dataset_test, num_replicas=world_size, rank=rank, shuffle=False
70
+ )
71
+
72
+ train_loader = DataLoader(
73
+ dataset_train, sampler=sampler_train, batch_size=args.batch_size, shuffle=False, collate_fn=collate,
74
+ num_workers=args.n_workers,
75
+ pin_memory=True,
76
+ prefetch_factor=2,
77
+ drop_last=True
78
+ )
79
+ test_loader = DataLoader(
80
+ dataset_test, sampler=sampler_test, batch_size=args.batch_size, shuffle=False, collate_fn=collate,
81
+ num_workers=args.n_workers,
82
+ pin_memory=True,
83
+ prefetch_factor=2,
84
+ drop_last=True
85
+ )
86
+
87
+ n_scenes_epoch = len(train_loader) * args.batch_size * world_size
88
+ print(f"Number of scenes in epoch: {n_scenes_epoch}")
89
+ else:
90
+ if args.dataset_just_load is None:
91
+ dataset = eval(args.dataset_train)
92
+ else:
93
+ dataset = eval(args.dataset_just_load)
94
+ if not dataset.prefetch_dino:
95
+ load_dino = True
96
+ rank, world_size, local_rank = misc.setup_distributed()
97
+ sampler_train = torch.utils.data.DistributedSampler(
98
+ dataset, num_replicas=world_size, rank=rank, shuffle=False
99
+ )
100
+ just_loader = DataLoader(dataset, sampler=sampler_train, batch_size=args.batch_size, shuffle=False, collate_fn=collate,
101
+ pin_memory=True,
102
+ drop_last=True
103
+ )
104
+
105
+ model = eval(args.model).to(args.device)
106
+ if args.augmentor != 'None':
107
+ augmentor = eval(args.augmentor)
108
+ else:
109
+ augmentor = None
110
+
111
+ if load_dino and len(model.dino_layers) > 0:
112
+ dino_model = torch.hub.load('facebookresearch/dinov2', "dinov2_vitl14_reg")
113
+ dino_model.eval()
114
+ dino_model.to("cuda")
115
+ else:
116
+ dino_model = None
117
+ # distribute model
118
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],find_unused_parameters=True)
119
+ model_without_ddp = model.module if hasattr(model, 'module') else model
120
+
121
+ eff_batch_size = args.batch_size * misc.get_world_size()
122
+ if args.lr is None: # only base_lr is specified
123
+ args.lr = args.blr * eff_batch_size / 256
124
+
125
+ param_groups = misc.add_weight_decay(model_without_ddp, args.weight_decay)
126
+ optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
127
+ os.makedirs(args.logdir,exist_ok=True)
128
+ start_epoch = 0
129
+ print("Running on host %s" % socket.gethostname())
130
+ if args.resume and os.path.exists(os.path.join(args.resume, "checkpoint-latest.pth")):
131
+ checkpoint = torch.load(os.path.join(args.resume, "checkpoint-latest.pth"), map_location='cpu')
132
+ model_without_ddp.load_state_dict(checkpoint['model'])
133
+ model_params = list(model.parameters())
134
+ print("Resume checkpoint %s" % args.resume)
135
+
136
+ if 'optimizer' in checkpoint and 'epoch' in checkpoint:
137
+ optimizer.load_state_dict(checkpoint['optimizer'])
138
+ start_epoch = checkpoint['epoch'] + 1
139
+ print("With optim & sched!")
140
+ del checkpoint
141
+ elif args.start_from is not None:
142
+ checkpoint = torch.load(args.start_from, map_location='cpu')
143
+ model_without_ddp.load_state_dict(checkpoint['model'])
144
+ print("Start from checkpoint %s" % args.start_from)
145
+ if args.just_load:
146
+ with torch.no_grad():
147
+ while True:
148
+ #test_log_dict = eval_epoch(model,just_loader,device=args.device,dino_model=dino_model,args=args)
149
+ for data in just_loader:
150
+ pred, gt, loss_dict, batch = eval_model(model,data,mode='viz',args=args,dino_model=dino_model,augmentor=augmentor)
151
+ # cast to float32 for visualization
152
+ gt = {k: v.float() for k, v in gt.items()}
153
+ pred = {k: v.float() for k, v in pred.items()}
154
+ #loss_dict = eval_model(model,data,mode='loss',device=args.device)
155
+ #print(f"Loss: {loss_dict['loss']:.4f}")
156
+ # summarize all keys in loss_dict in table
157
+ print(f"{'Key':<10} {'Value':<10}")
158
+ print("-"*20)
159
+ for key, value in loss_dict.items():
160
+ print(f"{key:<10}: {value:.4f}")
161
+ print("-"*20)
162
+ name = args.logdir
163
+ addr = args.rr_addr
164
+ if args.mesh:
165
+ fused_meshes = fuse_batch(pred,gt,data, voxel_size=0.002)
166
+ else:
167
+ fused_meshes = None
168
+ just_load_viz(pred,gt,batch,addr=addr,name=name,fused_meshes=fused_meshes)
169
+ breakpoint()
170
+ return
171
+ else:
172
+ if args.wandb_project and misc.get_rank() == 0:
173
+ wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=args)
174
+ log_wandb = args.wandb_project
175
+ else:
176
+ log_wandb = None
177
+ for epoch in range(start_epoch,args.n_epochs):
178
+ start_time = time.time()
179
+ log_dict = train_epoch(model,train_loader,optimizer,device=args.device,max_norm=args.max_norm,epoch=epoch,
180
+ log_wandb=log_wandb,batch_size=eff_batch_size,args=args,dino_model=dino_model,augmentor=augmentor)
181
+ end_time = time.time()
182
+ print(f"Epoch {epoch} train loss: {log_dict['loss']:.4f} grad_norm: {log_dict['grad_norm']:.4f} \n")
183
+ print(f"Time taken for epoch {epoch}: {end_time - start_time:.2f} seconds")
184
+
185
+ if epoch % args.eval_every == 0:
186
+ test_log_dict = eval_epoch(model,test_loader,device=args.device,dino_model=dino_model,args=args,augmentor=augmentor)
187
+ print(f"Epoch {epoch} test loss: {test_log_dict['loss']:.4f} \n")
188
+ if log_wandb:
189
+ wandb_dict = {f"test_{k}":v for k,v in test_log_dict.items()}
190
+ wandb.log(wandb_dict, step=(epoch+1)*n_scenes_epoch)
191
+ if epoch % args.save_every == 0:
192
+ # saving a separate checkpoint per epoch avoids overwriting, but the checkpoints quickly use a lot of disk space
193
+ #misc.save_model(args, epoch, model, optimizer)
194
+ misc.save_model(args, epoch, model_without_ddp, optimizer, epoch_name="latest")
195
+
196
+ if __name__ == "__main__":
197
+ args = parse_args()
198
+ main(args)
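
Note that --dataset_train, --dataset_test, --dataset_just_load, --model, and --augmentor are passed to eval(), so they must be valid Python expressions over the names imported at the top of main.py. For example, the default --model string builds roughly the following (a sketch; the loss names are assumed to come from models.losses via the star import):

from models.rayquery import RayQuery, PointmapEncoder, RayEncoder
from models.losses import RayCompletion, ConfLoss, L21   # assumption: exported by models/losses.py

model = RayQuery(ray_enc=RayEncoder(), pointmap_enc=PointmapEncoder(),
                 criterion=RayCompletion(ConfLoss(L21)))
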
models/blocks.py ADDED
@@ -0,0 +1,235 @@
1
+ # Copied from: https://github.com/naver/croco/blob/743ee71a2a9bf57cea6832a9064a70a0597fcfcb/models/blocks.py
2
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
3
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from itertools import repeat
9
+ import collections.abc
10
+
11
+ def _ntuple(n):
12
+ def parse(x):
13
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
14
+ return x
15
+ return tuple(repeat(x, n))
16
+ return parse
17
+ to_2tuple = _ntuple(2)
18
+
19
+ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
20
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
21
+ """
22
+ if drop_prob == 0. or not training:
23
+ return x
24
+ keep_prob = 1 - drop_prob
25
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
26
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
27
+ if keep_prob > 0.0 and scale_by_keep:
28
+ random_tensor.div_(keep_prob)
29
+ return x * random_tensor
30
+
31
+ class DropPath(nn.Module):
32
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
33
+ """
34
+ def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
35
+ super(DropPath, self).__init__()
36
+ self.drop_prob = drop_prob
37
+ self.scale_by_keep = scale_by_keep
38
+
39
+ def forward(self, x):
40
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
41
+
42
+ def extra_repr(self):
43
+ return f'drop_prob={round(self.drop_prob,3):0.3f}'
44
+
45
+ class Mlp(nn.Module):
46
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks"""
47
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
48
+ super().__init__()
49
+ out_features = out_features or in_features
50
+ hidden_features = hidden_features or in_features
51
+ bias = to_2tuple(bias)
52
+ drop_probs = to_2tuple(drop)
53
+
54
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
55
+ self.act = act_layer()
56
+ self.drop1 = nn.Dropout(drop_probs[0])
57
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
58
+ self.drop2 = nn.Dropout(drop_probs[1])
59
+
60
+ def forward(self, x):
61
+ x = self.fc1(x)
62
+ x = self.act(x)
63
+ x = self.drop1(x)
64
+ x = self.fc2(x)
65
+ x = self.drop2(x)
66
+ return x
67
+
68
+ class Attention(nn.Module):
69
+
70
+ def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
71
+ super().__init__()
72
+ self.num_heads = num_heads
73
+ head_dim = dim // num_heads
74
+ self.scale = head_dim ** -0.5
75
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
76
+ self.attn_drop = nn.Dropout(attn_drop)
77
+ self.proj = nn.Linear(dim, dim)
78
+ self.proj_drop = nn.Dropout(proj_drop)
79
+ self.rope = rope
80
+
81
+ def forward(self, x, xpos):
82
+ B, N, C = x.shape
83
+
84
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1,3)
85
+ q, k, v = [qkv[:,:,i] for i in range(3)]
86
+ # q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple)
87
+
88
+ if self.rope is not None:
89
+ q = self.rope(q, xpos)
90
+ k = self.rope(k, xpos)
91
+
92
+ attn = (q @ k.transpose(-2, -1)) * self.scale
93
+ attn = attn.softmax(dim=-1)
94
+ attn = self.attn_drop(attn)
95
+
96
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
97
+ x = self.proj(x)
98
+ x = self.proj_drop(x)
99
+ return x
100
+
101
+ class Block(nn.Module):
102
+
103
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
104
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, rope=None):
105
+ super().__init__()
106
+ self.norm1 = norm_layer(dim)
107
+ self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
108
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
109
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
110
+ self.norm2 = norm_layer(dim)
111
+ mlp_hidden_dim = int(dim * mlp_ratio)
112
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
113
+
114
+ def forward(self, x, xpos):
115
+ x = x + self.drop_path(self.attn(self.norm1(x), xpos))
116
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
117
+ return x
118
+
119
+ class CrossAttention(nn.Module):
120
+
121
+ def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
122
+ super().__init__()
123
+ self.num_heads = num_heads
124
+ head_dim = dim // num_heads
125
+ self.scale = head_dim ** -0.5
126
+
127
+ self.projq = nn.Linear(dim, dim, bias=qkv_bias)
128
+ self.projk = nn.Linear(dim, dim, bias=qkv_bias)
129
+ self.projv = nn.Linear(dim, dim, bias=qkv_bias)
130
+ self.attn_drop = nn.Dropout(attn_drop)
131
+ self.proj = nn.Linear(dim, dim)
132
+ self.proj_drop = nn.Dropout(proj_drop)
133
+
134
+ self.rope = rope
135
+
136
+ def forward(self, query, key, value, qpos, kpos):
137
+ B, Nq, C = query.shape
138
+ Nk = key.shape[1]
139
+ Nv = value.shape[1]
140
+
141
+ q = self.projq(query).reshape(B,Nq,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
142
+ k = self.projk(key).reshape(B,Nk,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
143
+ v = self.projv(value).reshape(B,Nv,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
144
+
145
+ if self.rope is not None:
146
+ q = self.rope(q, qpos)
147
+ k = self.rope(k, kpos)
148
+
149
+ attn = (q @ k.transpose(-2, -1)) * self.scale
150
+ attn = attn.softmax(dim=-1)
151
+ attn = self.attn_drop(attn)
152
+
153
+ x = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
154
+ x = self.proj(x)
155
+ x = self.proj_drop(x)
156
+ return x
157
+
158
+ class DecoderBlock(nn.Module):
159
+
160
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
161
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None,order='sa_ca'):
162
+ super().__init__()
163
+ self.norm1 = norm_layer(dim)
164
+ self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
165
+ self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
166
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
167
+ self.norm2 = norm_layer(dim)
168
+ self.norm3 = norm_layer(dim)
169
+ mlp_hidden_dim = int(dim * mlp_ratio)
170
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
171
+ self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
172
+ self.order = order
173
+ self.batch_drop_path_prob = -drop_path if drop_path < 0. else 0.
174
+
175
+ def forward(self, x, y, xpos, ypos):
176
+ if self.order == 'sa_ca':
177
+ if self.batch_drop_path_prob==0.0 or not self.training or torch.rand(1).item()>=self.batch_drop_path_prob: x = x + self.drop_path(self.attn(self.norm1(x), xpos))
178
+ y_ = self.norm_y(y)
179
+ if self.batch_drop_path_prob==0.0 or not self.training or torch.rand(1).item()>=self.batch_drop_path_prob: x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
180
+ if self.batch_drop_path_prob==0.0 or not self.training or torch.rand(1).item()>=self.batch_drop_path_prob: x = x + self.drop_path(self.mlp(self.norm3(x)))
181
+ elif self.order == 'ca_sa':
182
+ y_ = self.norm_y(y)
183
+ if self.batch_drop_path_prob==0.0 or not self.training or torch.rand(1).item()>=self.batch_drop_path_prob: x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
184
+ if self.batch_drop_path_prob==0.0 or not self.training or torch.rand(1).item()>=self.batch_drop_path_prob: x = x + self.drop_path(self.attn(self.norm1(x), xpos))
185
+ if self.batch_drop_path_prob==0.0 or not self.training or torch.rand(1).item()>=self.batch_drop_path_prob: x = x + self.drop_path(self.mlp(self.norm3(x)))
186
+ return x, y
187
+
188
+
189
+ # patch embedding
190
+ class PositionGetter(object):
191
+ """ return positions of patches """
192
+
193
+ def __init__(self):
194
+ self.cache_positions = {}
195
+
196
+ def __call__(self, b, h, w, device):
197
+ if not (h,w) in self.cache_positions:
198
+ x = torch.arange(w, device=device)
199
+ y = torch.arange(h, device=device)
200
+ self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h*w, 2)
201
+ pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone()
202
+ return pos
203
+
204
+ class PatchEmbed(nn.Module):
205
+ """ just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed"""
206
+
207
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
208
+ super().__init__()
209
+ img_size = to_2tuple(img_size)
210
+ patch_size = to_2tuple(patch_size)
211
+ self.img_size = img_size
212
+ self.patch_size = patch_size
213
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
214
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
215
+ self.flatten = flatten
216
+
217
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
218
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
219
+
220
+ self.position_getter = PositionGetter()
221
+
222
+ def forward(self, x):
223
+ B, C, H, W = x.shape
224
+ torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
225
+ torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
226
+ x = self.proj(x)
227
+ pos = self.position_getter(B, x.size(2), x.size(3), x.device)
228
+ if self.flatten:
229
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
230
+ x = self.norm(x)
231
+ return x, pos
232
+
233
+ def _init_weights(self):
234
+ w = self.proj.weight.data
235
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
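
A self-contained sketch of how the building blocks above compose (rope=None, so plain attention is used; image size and dimensions are illustrative):

import torch
from models.blocks import PatchEmbed, Block, DecoderBlock

embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
enc_block = Block(dim=768, num_heads=12, rope=None)
dec_block = DecoderBlock(dim=768, num_heads=12, rope=None)

img = torch.randn(2, 3, 224, 224)
x, xpos = embed(img)                # (2, 196, 768) tokens and (2, 196, 2) patch positions
x = enc_block(x, xpos)              # self-attention block
y, ypos = x.clone(), xpos.clone()   # stand-in for memory tokens from another view
x, y = dec_block(x, y, xpos, ypos)  # self-attention followed by cross-attention
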
models/heads/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # head factory
6
+ # --------------------------------------------------------
7
+ from .linear_head import LinearPts3d
8
+ from .dpt_head import create_dpt_head, create_dpt_head_mask, create_dpt_head_depth
9
+
10
+ def head_factory(head_type, output_mode, net, has_conf=False):
11
+ """" build a prediction head for the decoder
12
+ """
13
+ if head_type == 'linear' and output_mode == 'pts3d':
14
+ return LinearPts3d(net, has_conf)
15
+ if head_type == 'linear_depth' and output_mode == 'pts3d':
16
+ return LinearPts3d(net, has_conf,mode='depth')
17
+ if head_type == 'linear_classifier' and output_mode == 'pts3d':
18
+ return LinearPts3d(net, has_conf,mode='classifier')
19
+ elif head_type == 'dpt' and output_mode == 'pts3d':
20
+ return create_dpt_head(net, has_conf=has_conf)
21
+ elif head_type == 'dpt_depth' and output_mode == 'pts3d':
22
+ return create_dpt_head_depth(net, has_conf=has_conf)
23
+ elif head_type == 'dpt_mask' and output_mode == 'pts3d':
24
+ return create_dpt_head_mask(net, has_conf=has_conf)
25
+ else:
26
+ raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")
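
The factory only reads a handful of attributes from the parent network; a minimal construction sketch with a stand-in config object (the attribute values below are placeholders, not the ones RaySt3R actually uses):

from types import SimpleNamespace
from models.heads import head_factory

net = SimpleNamespace(
    patch_size=16, dec_embed_dim=768, enc_embed_dim=1024, dec_depth=12,
    depth_mode=('exp', -float('inf'), float('inf')),
    conf_mode=('exp', 1, float('inf')),
    classifier_mode=None,
)

linear_head = head_factory('linear', 'pts3d', net, has_conf=True)  # LinearPts3d
dpt_head = head_factory('dpt', 'pts3d', net, has_conf=True)        # PixelwiseTaskWithDPT
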
models/heads/dpt_head.py ADDED
@@ -0,0 +1,582 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange, repeat
5
+ from typing import Union, Tuple, Iterable, List, Optional, Dict
6
+ from .postprocess import postprocess
7
+
8
+ def pair(t):
9
+ return t if isinstance(t, tuple) else (t, t)
10
+
11
+ def make_scratch(in_shape, out_shape, groups=1, expand=False):
12
+ scratch = nn.Module()
13
+
14
+ out_shape1 = out_shape
15
+ out_shape2 = out_shape
16
+ out_shape3 = out_shape
17
+ out_shape4 = out_shape
18
+ if expand == True:
19
+ out_shape1 = out_shape
20
+ out_shape2 = out_shape * 2
21
+ out_shape3 = out_shape * 4
22
+ out_shape4 = out_shape * 8
23
+
24
+ scratch.layer1_rn = nn.Conv2d(
25
+ in_shape[0],
26
+ out_shape1,
27
+ kernel_size=3,
28
+ stride=1,
29
+ padding=1,
30
+ bias=False,
31
+ groups=groups,
32
+ )
33
+ scratch.layer2_rn = nn.Conv2d(
34
+ in_shape[1],
35
+ out_shape2,
36
+ kernel_size=3,
37
+ stride=1,
38
+ padding=1,
39
+ bias=False,
40
+ groups=groups,
41
+ )
42
+ scratch.layer3_rn = nn.Conv2d(
43
+ in_shape[2],
44
+ out_shape3,
45
+ kernel_size=3,
46
+ stride=1,
47
+ padding=1,
48
+ bias=False,
49
+ groups=groups,
50
+ )
51
+ scratch.layer4_rn = nn.Conv2d(
52
+ in_shape[3],
53
+ out_shape4,
54
+ kernel_size=3,
55
+ stride=1,
56
+ padding=1,
57
+ bias=False,
58
+ groups=groups,
59
+ )
60
+
61
+ scratch.layer_rn = nn.ModuleList([
62
+ scratch.layer1_rn,
63
+ scratch.layer2_rn,
64
+ scratch.layer3_rn,
65
+ scratch.layer4_rn,
66
+ ])
67
+
68
+ return scratch
69
+
70
+ class ResidualConvUnit_custom(nn.Module):
71
+ """Residual convolution module."""
72
+
73
+ def __init__(self, features, activation, bn):
74
+ """Init.
75
+ Args:
76
+ features (int): number of features
77
+ """
78
+ super().__init__()
79
+
80
+ self.bn = bn
81
+ self.groups = 1
82
+
83
+ self.conv1 = nn.Conv2d(
84
+ features,
85
+ features,
86
+ kernel_size=3,
87
+ stride=1,
88
+ padding=1,
89
+ bias=not self.bn,
90
+ groups=self.groups,
91
+ )
92
+
93
+ self.conv2 = nn.Conv2d(
94
+ features,
95
+ features,
96
+ kernel_size=3,
97
+ stride=1,
98
+ padding=1,
99
+ bias=not self.bn,
100
+ groups=self.groups,
101
+ )
102
+
103
+ if self.bn == True:
104
+ self.bn1 = nn.BatchNorm2d(features)
105
+ self.bn2 = nn.BatchNorm2d(features)
106
+
107
+ self.activation = activation
108
+
109
+ self.skip_add = nn.quantized.FloatFunctional()
110
+
111
+ def forward(self, x):
112
+ """Forward pass.
113
+ Args:
114
+ x (tensor): input
115
+ Returns:
116
+ tensor: output
117
+ """
118
+
119
+ out = self.activation(x)
120
+ out = self.conv1(out)
121
+ if self.bn == True:
122
+ out = self.bn1(out)
123
+
124
+ out = self.activation(out)
125
+ out = self.conv2(out)
126
+ if self.bn == True:
127
+ out = self.bn2(out)
128
+
129
+ if self.groups > 1:
130
+ out = self.conv_merge(out)
131
+
132
+ return self.skip_add.add(out, x)
133
+
134
+ class FeatureFusionBlock_custom(nn.Module):
135
+ """Feature fusion block."""
136
+
137
+ def __init__(
138
+ self,
139
+ features,
140
+ activation,
141
+ deconv=False,
142
+ bn=False,
143
+ expand=False,
144
+ align_corners=True,
145
+ width_ratio=1,
146
+ ):
147
+ """Init.
148
+ Args:
149
+ features (int): number of features
150
+ """
151
+ super(FeatureFusionBlock_custom, self).__init__()
152
+ self.width_ratio = width_ratio
153
+
154
+ self.deconv = deconv
155
+ self.align_corners = align_corners
156
+
157
+ self.groups = 1
158
+
159
+ self.expand = expand
160
+ out_features = features
161
+ if self.expand == True:
162
+ out_features = features // 2
163
+
164
+ self.out_conv = nn.Conv2d(
165
+ features,
166
+ out_features,
167
+ kernel_size=1,
168
+ stride=1,
169
+ padding=0,
170
+ bias=True,
171
+ groups=1,
172
+ )
173
+
174
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
175
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
176
+
177
+ self.skip_add = nn.quantized.FloatFunctional()
178
+
179
+ def forward(self, *xs):
180
+ """Forward pass.
181
+ Returns:
182
+ tensor: output
183
+ """
184
+ output = xs[0]
185
+
186
+ if len(xs) == 2:
187
+ res = self.resConfUnit1(xs[1])
188
+ if self.width_ratio != 1:
189
+ res = F.interpolate(res, size=(output.shape[2], output.shape[3]), mode='bilinear')
190
+
191
+ output = self.skip_add.add(output, res)
192
+ # output += res
193
+
194
+ output = self.resConfUnit2(output)
195
+
196
+ if self.width_ratio != 1:
197
+ # and output.shape[3] < self.width_ratio * output.shape[2]
198
+ #size=(image.shape[])
199
+ if (output.shape[3] / output.shape[2]) < (2 / 3) * self.width_ratio:
200
+ shape = 3 * output.shape[3]
201
+ else:
202
+ shape = int(self.width_ratio * 2 * output.shape[2])
203
+ output = F.interpolate(output, size=(2* output.shape[2], shape), mode='bilinear')
204
+ else:
205
+ output = nn.functional.interpolate(output, scale_factor=2,
206
+ mode="bilinear", align_corners=self.align_corners)
207
+ output = self.out_conv(output)
208
+ return output
209
+
210
+ def make_fusion_block(features, use_bn, width_ratio=1):
211
+ return FeatureFusionBlock_custom(
212
+ features,
213
+ nn.ReLU(False),
214
+ deconv=False,
215
+ bn=use_bn,
216
+ expand=False,
217
+ align_corners=True,
218
+ width_ratio=width_ratio,
219
+ )
220
+
221
+ class Interpolate(nn.Module):
222
+ """Interpolation module."""
223
+
224
+ def __init__(self, scale_factor, mode, align_corners=False):
225
+ """Init.
226
+ Args:
227
+ scale_factor (float): scaling
228
+ mode (str): interpolation mode
229
+ """
230
+ super(Interpolate, self).__init__()
231
+
232
+ self.interp = nn.functional.interpolate
233
+ self.scale_factor = scale_factor
234
+ self.mode = mode
235
+ self.align_corners = align_corners
236
+
237
+ def forward(self, x):
238
+ """Forward pass.
239
+ Args:
240
+ x (tensor): input
241
+ Returns:
242
+ tensor: interpolated data
243
+ """
244
+
245
+ x = self.interp(
246
+ x,
247
+ scale_factor=self.scale_factor,
248
+ mode=self.mode,
249
+ align_corners=self.align_corners,
250
+ )
251
+
252
+ return x
253
+
254
+ class DPTOutputAdapter(nn.Module):
255
+ """DPT output adapter.
256
+
257
+ :param num_channels: Number of output channels
258
+ :param stride_level: stride level compared to the full-sized image.
259
+ E.g. 4 for 1/4th the size of the image.
260
+ :param patch_size_full: Int or tuple of the patch size over the full image size.
261
+ Patch size for smaller inputs will be computed accordingly.
262
+ :param hooks: Index of intermediate layers
263
+ :param layer_dims: Dimension of intermediate layers
264
+ :param feature_dim: Feature dimension
265
+ :param last_dim: out_channels/in_channels for the last two Conv2d when head_type == regression
266
+ :param use_bn: If set to True, activates batch norm
267
+ :param dim_tokens_enc: Dimension of tokens coming from encoder
268
+ """
269
+
270
+ def __init__(self,
271
+ num_channels: int = 1,
272
+ stride_level: int = 1,
273
+ patch_size: Union[int, Tuple[int, int]] = 16,
274
+ main_tasks: Iterable[str] = ('rgb',),
275
+ hooks: List[int] = [2, 5, 8, 11],
276
+ layer_dims: List[int] = [96, 192, 384, 768],
277
+ feature_dim: int = 256,
278
+ last_dim: int = 32,
279
+ use_bn: bool = False,
280
+ dim_tokens_enc: Optional[int] = None,
281
+ head_type: str = 'regression',
282
+ output_width_ratio=1,
283
+ **kwargs):
284
+ super().__init__()
285
+ self.num_channels = num_channels
286
+ self.stride_level = stride_level
287
+ self.patch_size = pair(patch_size)
288
+ self.main_tasks = main_tasks
289
+ self.hooks = hooks
290
+ self.layer_dims = layer_dims
291
+ self.feature_dim = feature_dim
292
+ self.dim_tokens_enc = dim_tokens_enc * len(self.main_tasks) if dim_tokens_enc is not None else None
293
+ self.head_type = head_type
294
+
295
+ # Actual patch height and width, taking into account stride of input
296
+ self.P_H = max(1, self.patch_size[0] // stride_level)
297
+ self.P_W = max(1, self.patch_size[1] // stride_level)
298
+
299
+ self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False)
300
+ self.scratch.refinenet1 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
301
+ self.scratch.refinenet2 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
302
+ self.scratch.refinenet3 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
303
+ self.scratch.refinenet4 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
304
+
305
+ if self.head_type == 'regression':
306
+ # The "DPTDepthModel" head
307
+ self.head = nn.Sequential(
308
+ nn.Conv2d(feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1),
309
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
310
+ nn.Conv2d(feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1),
311
+ nn.ReLU(True),
312
+ nn.Conv2d(last_dim, self.num_channels, kernel_size=1, stride=1, padding=0)
313
+ )
314
+ elif self.head_type == 'semseg':
315
+ # The "DPTSegmentationModel" head
316
+ self.head = nn.Sequential(
317
+ nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1, bias=False),
318
+ nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(),
319
+ nn.ReLU(True),
320
+ nn.Dropout(0.1, False),
321
+ nn.Conv2d(feature_dim, self.num_channels, kernel_size=1),
322
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
323
+ )
324
+ else:
325
+ raise ValueError('DPT head_type must be "regression" or "semseg".')
326
+
327
+ if self.dim_tokens_enc is not None:
328
+ self.init(dim_tokens_enc=dim_tokens_enc)
329
+
330
+ def init(self, dim_tokens_enc=768):
331
+ """
332
+ Initialize parts of decoder that are dependent on dimension of encoder tokens.
333
+ Should be called when setting up MultiMAE.
334
+
335
+ :param dim_tokens_enc: Dimension of tokens coming from encoder
336
+ """
337
+ #print(dim_tokens_enc)
338
+
339
+ # Set up activation postprocessing layers
340
+ if isinstance(dim_tokens_enc, int):
341
+ dim_tokens_enc = 4 * [dim_tokens_enc]
342
+
343
+ self.dim_tokens_enc = [dt * len(self.main_tasks) for dt in dim_tokens_enc]
344
+
345
+ self.act_1_postprocess = nn.Sequential(
346
+ nn.Conv2d(
347
+ in_channels=self.dim_tokens_enc[0],
348
+ out_channels=self.layer_dims[0],
349
+ kernel_size=1, stride=1, padding=0,
350
+ ),
351
+ nn.ConvTranspose2d(
352
+ in_channels=self.layer_dims[0],
353
+ out_channels=self.layer_dims[0],
354
+ kernel_size=4, stride=4, padding=0,
355
+ bias=True, dilation=1, groups=1,
356
+ )
357
+ )
358
+
359
+ self.act_2_postprocess = nn.Sequential(
360
+ nn.Conv2d(
361
+ in_channels=self.dim_tokens_enc[1],
362
+ out_channels=self.layer_dims[1],
363
+ kernel_size=1, stride=1, padding=0,
364
+ ),
365
+ nn.ConvTranspose2d(
366
+ in_channels=self.layer_dims[1],
367
+ out_channels=self.layer_dims[1],
368
+ kernel_size=2, stride=2, padding=0,
369
+ bias=True, dilation=1, groups=1,
370
+ )
371
+ )
372
+
373
+ self.act_3_postprocess = nn.Sequential(
374
+ nn.Conv2d(
375
+ in_channels=self.dim_tokens_enc[2],
376
+ out_channels=self.layer_dims[2],
377
+ kernel_size=1, stride=1, padding=0,
378
+ )
379
+ )
380
+
381
+ self.act_4_postprocess = nn.Sequential(
382
+ nn.Conv2d(
383
+ in_channels=self.dim_tokens_enc[3],
384
+ out_channels=self.layer_dims[3],
385
+ kernel_size=1, stride=1, padding=0,
386
+ ),
387
+ nn.Conv2d(
388
+ in_channels=self.layer_dims[3],
389
+ out_channels=self.layer_dims[3],
390
+ kernel_size=3, stride=2, padding=1,
391
+ )
392
+ )
393
+
394
+ self.act_postprocess = nn.ModuleList([
395
+ self.act_1_postprocess,
396
+ self.act_2_postprocess,
397
+ self.act_3_postprocess,
398
+ self.act_4_postprocess
399
+ ])
400
+
401
+ def adapt_tokens(self, encoder_tokens):
402
+ # Adapt tokens
403
+ x = []
404
+ x.append(encoder_tokens[:, :])
405
+ x = torch.cat(x, dim=-1)
406
+ return x
407
+
408
+ def forward(self, encoder_tokens: List[torch.Tensor], image_size):
409
+ #input_info: Dict):
410
+ assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
411
+ H, W = image_size
412
+
413
+ # Number of patches in height and width
414
+ N_H = H // (self.stride_level * self.P_H)
415
+ N_W = W // (self.stride_level * self.P_W)
416
+
417
+ # Hook decoder onto 4 layers from specified ViT layers
418
+ layers = [encoder_tokens[hook] for hook in self.hooks]
419
+
420
+ # Extract only task-relevant tokens and ignore global tokens.
421
+ layers = [self.adapt_tokens(l) for l in layers]
422
+
423
+ # Reshape tokens to spatial representation
424
+ layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
425
+
426
+ layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
427
+ # Project layers to chosen feature dim
428
+ layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
429
+
430
+ # Fuse layers using refinement stages
431
+ path_4 = self.scratch.refinenet4(layers[3])
432
+ path_3 = self.scratch.refinenet3(path_4, layers[2])
433
+ path_2 = self.scratch.refinenet2(path_3, layers[1])
434
+ path_1 = self.scratch.refinenet1(path_2, layers[0])
435
+
436
+ # Output head
437
+ out = self.head(path_1)
438
+
439
+ return out
440
+
441
+ class DPTOutputAdapter_fix(DPTOutputAdapter):
442
+ """
443
+ Adapt croco's DPTOutputAdapter implementation for dust3r:
444
+ remove duplicated weigths, and fix forward for dust3r
445
+ """
446
+
447
+ def init(self, dim_tokens_enc=768,**kwargs):
448
+ super().init(dim_tokens_enc,**kwargs)
449
+ # these are duplicated weights
450
+ del self.act_1_postprocess
451
+ del self.act_2_postprocess
452
+ del self.act_3_postprocess
453
+ del self.act_4_postprocess
454
+
455
+ def forward(self, encoder_tokens: List[torch.Tensor], image_size=None):
456
+ assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
457
+ # H, W = input_info['image_size']
458
+ image_size = self.image_size if image_size is None else image_size
459
+ H, W = image_size
460
+ # Number of patches in height and width
461
+ N_H = H // (self.stride_level * self.P_H)
462
+ N_W = W // (self.stride_level * self.P_W)
463
+
464
+ # Hook decoder onto 4 layers from specified ViT layers
465
+ layers = [encoder_tokens[hook] for hook in self.hooks]
466
+
467
+ # Extract only task-relevant tokens and ignore global tokens.
468
+ layers = [self.adapt_tokens(l) for l in layers]
469
+
470
+ # Reshape tokens to spatial representation
471
+ layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
472
+
473
+ layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
474
+ # Project layers to chosen feature dim
475
+ layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
476
+
477
+ # Fuse layers using refinement stages
478
+ path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]]
479
+ path_3 = self.scratch.refinenet3(path_4, layers[2])
480
+ path_2 = self.scratch.refinenet2(path_3, layers[1])
481
+ path_1 = self.scratch.refinenet1(path_2, layers[0])
482
+
483
+ # Output head
484
+ out = self.head(path_1)
485
+ return out
486
+
487
+
488
+ class PixelwiseTaskWithDPT(nn.Module):
489
+ """ DPT module for dust3r, can return 3D points + confidence for all pixels"""
490
+
491
+ def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
492
+ output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, classifier_mode=None, **kwargs):
493
+ super(PixelwiseTaskWithDPT, self).__init__()
494
+ self.return_all_layers = True # backbone needs to return all layers
495
+ self.postprocess = postprocess
496
+ self.depth_mode = depth_mode
497
+ self.conf_mode = conf_mode
498
+ self.classifier_mode = classifier_mode
499
+
500
+ assert n_cls_token == 0, "Not implemented"
501
+ dpt_args = dict(output_width_ratio=output_width_ratio,
502
+ num_channels=num_channels,
503
+ **kwargs)
504
+ if hooks_idx is not None:
505
+ dpt_args.update(hooks=hooks_idx)
506
+ self.dpt = DPTOutputAdapter_fix(**dpt_args)
507
+ dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens}
508
+ self.dpt.init(**dpt_init_args)
509
+
510
+ def forward(self, x, img_info):
511
+ out = self.dpt(x, image_size=(img_info[0], img_info[1]))
512
+ if self.postprocess:
513
+ out = self.postprocess(out, self.depth_mode, self.conf_mode,self.classifier_mode)
514
+ return out
515
+
516
+ def create_dpt_head(net, has_conf=False):
517
+ """
518
+ return PixelwiseTaskWithDPT for given net params
519
+ """
520
+ assert net.dec_depth > 9
521
+ l2 = net.dec_depth - 1
522
+ feature_dim = 256
523
+ last_dim = feature_dim//2
524
+ out_nchan = 3
525
+ ed = net.enc_embed_dim
526
+ dd = net.dec_embed_dim
527
+ return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf,
528
+ feature_dim=feature_dim,
529
+ last_dim=last_dim,
530
+ hooks_idx=[0, l2*2//4, l2*3//4, l2],
531
+ dim_tokens=[ed, dd, dd, dd],
532
+ postprocess=postprocess,
533
+ depth_mode=net.depth_mode,
534
+ conf_mode=net.conf_mode,
535
+ head_type='regression',
536
+ patch_size=net.patch_size)
537
+
538
+ def create_dpt_head_depth(net, has_conf=False):
539
+ """
540
+ return PixelwiseTaskWithDPT for given net params
541
+ """
542
+ assert net.dec_depth > 9
543
+ l2 = net.dec_depth - 1
544
+ feature_dim = 256
545
+ last_dim = feature_dim//2
546
+ out_nchan = 1
547
+ ed = net.enc_embed_dim
548
+ dd = net.dec_embed_dim
549
+ return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf,
550
+ feature_dim=feature_dim,
551
+ last_dim=last_dim,
552
+ hooks_idx=[0, l2*2//4, l2*3//4, l2],
553
+ dim_tokens=[ed, dd, dd, dd],
554
+ postprocess=postprocess,
555
+ depth_mode=net.depth_mode,
556
+ conf_mode=net.conf_mode,
557
+ head_type='regression',
558
+ patch_size=net.patch_size)
559
+
560
+
561
+ def create_dpt_head_mask(net, has_conf=False):
562
+ """
563
+ return PixelwiseTaskWithDPT for given net params
564
+ """
565
+ assert net.dec_depth > 9
566
+ l2 = net.dec_depth - 1
567
+ feature_dim = 256
568
+ last_dim = feature_dim//2
569
+ out_nchan = 3
570
+ ed = net.enc_embed_dim
571
+ dd = net.dec_embed_dim
572
+ return PixelwiseTaskWithDPT(num_channels=1 + has_conf,
573
+ feature_dim=feature_dim,
574
+ last_dim=last_dim,
575
+ hooks_idx=[0, l2*2//4, l2*3//4, l2],
576
+ dim_tokens=[ed, dd, dd, dd],
577
+ postprocess=postprocess,
578
+ depth_mode=net.depth_mode,
579
+ conf_mode=net.conf_mode,
580
+ classifier_mode=net.classifier_mode,
581
+ head_type='regression',
582
+ patch_size=net.patch_size)
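
For reference, the three factories above hook the DPT decoder onto four evenly spaced decoder layers. A minimal sketch of the index arithmetic (dec_depth=12 is an illustrative value, not a repo default):

```python
# Hypothetical decoder depth; the factories above assert net.dec_depth > 9.
dec_depth = 12
l2 = dec_depth - 1                             # index of the last decoder block
hooks_idx = [0, l2 * 2 // 4, l2 * 3 // 4, l2]  # evenly spaced hook layers
print(hooks_idx)                               # [0, 5, 8, 11]
```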
models/heads/linear_head.py ADDED
@@ -0,0 +1,42 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from .postprocess import postprocess
5
+
6
+ class LinearPts3d (nn.Module):
7
+ """
8
+ Linear head for dust3r
9
+ Each token outputs: - 16x16 3D points (+ confidence)
10
+ """
11
+
12
+ def __init__(self, net, has_conf=False,mode='pts3d'):
13
+ super().__init__()
14
+ self.patch_size = net.patch_size
15
+ self.depth_mode = net.depth_mode
16
+ self.conf_mode = net.conf_mode
17
+ self.has_conf = has_conf
18
+ self.mode = mode
19
+ self.classifier_mode = None
20
+ if self.mode == 'pts3d':
21
+ self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2)
22
+ elif self.mode == 'depth':
23
+ self.proj = nn.Linear(net.dec_embed_dim, (1 + has_conf)*self.patch_size**2)
24
+ elif self.mode == 'classifier':
25
+ self.proj = nn.Linear(net.dec_embed_dim, (1 + has_conf)*self.patch_size**2)
26
+ self.classifier_mode = net.classifier_mode
27
+
28
+ def setup(self, croconet):
29
+ pass
30
+
31
+ def forward(self, decout, img_shape):
32
+ H, W = img_shape
33
+ tokens = decout[-1]
34
+ B, S, D = tokens.shape
35
+
36
+ # extract 3D points
37
+ feat = self.proj(tokens) # B,S,D
38
+ feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size)
39
+ feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W
40
+
41
+ # permute + norm depth
42
+ return postprocess(feat, self.depth_mode, self.conf_mode,self.classifier_mode)
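
A hedged usage sketch for the linear head above; the `SimpleNamespace` is a stand-in for the real network object (only the attributes `LinearPts3d` reads), and the shapes are illustrative:

```python
import torch
from types import SimpleNamespace
from models.heads.linear_head import LinearPts3d

# Stand-in for the real net: only the attributes LinearPts3d reads.
net = SimpleNamespace(patch_size=16, dec_embed_dim=256,
                      depth_mode=('exp', -float('inf'), float('inf')),
                      conf_mode=('exp', 1, float('inf')))
head = LinearPts3d(net, has_conf=True, mode='pts3d')

H, W = 64, 64
tokens = torch.randn(2, (H // 16) * (W // 16), 256)   # B, S, D decoder tokens
out = head([tokens], img_shape=(H, W))
print(out['pointmaps'].shape)        # torch.Size([2, 64, 64, 3])
print(out['conf_pointmaps'].shape)   # torch.Size([2, 64, 64])
```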
models/heads/postprocess.py ADDED
@@ -0,0 +1,80 @@
1
+ import torch
2
+
3
+ def postprocess(out, depth_mode, conf_mode,classifier_mode=None):
4
+ """
5
+ extract 3D points/confidence from prediction head output
6
+ """
7
+ fmap = out.permute(0, 2, 3, 1) # B,H,W,3
8
+ if classifier_mode is None:
9
+ if fmap.shape[-1] == 4:
10
+ res = dict(pointmaps=reg_dense_pts3d(fmap[:, :, :, :-1], mode=depth_mode))
11
+ else:
12
+ res = dict(depths=reg_dense_depth(fmap[:, :, :, 0], mode=depth_mode))
13
+ if conf_mode is not None:
14
+ res['conf_pointmaps'] = reg_dense_conf(fmap[:, :, :, -1], mode=conf_mode)
15
+ else:
16
+ res = dict(classifier=reg_dense_classifier(fmap[:, :, :, 0], mode=classifier_mode))
17
+ if conf_mode is not None:
18
+ res['conf_classifier'] = reg_dense_conf(fmap[:, :, :, 1], mode=conf_mode)
19
+
20
+ return res
21
+
22
+ def reg_dense_classifier(x, mode):
23
+ """
24
+ extract classifier from prediction head output
25
+ """
26
+ mode, vmin, vmax = mode
27
+ #return torch.sigmoid(x)
28
+ return x
29
+
30
+ def reg_dense_depth(x, mode):
31
+ """
32
+ extract depth from prediction head output
33
+ """
34
+ mode, vmin, vmax = mode
35
+ no_bounds = (vmin == -float('inf')) and (vmax == float('inf'))
36
+ assert no_bounds
37
+ if mode == 'linear':
38
+ return x
39
+ elif mode == 'square':
40
+ return x.square().clip(min=vmin, max=vmax)
41
+ elif mode == 'exp':
42
+ return torch.exp(x).clip(min=vmin, max=vmax)
43
+ else:
44
+ raise ValueError(f'bad {mode=}')
45
+
46
+ def reg_dense_pts3d(xyz, mode):
47
+ """
48
+ extract 3D points from prediction head output
49
+ """
50
+ mode, vmin, vmax = mode
51
+
52
+ no_bounds = (vmin == -float('inf')) and (vmax == float('inf'))
53
+ assert no_bounds
54
+
55
+ if mode == 'linear':
56
+ if no_bounds:
57
+ return xyz # [-inf, +inf]
58
+ return xyz.clip(min=vmin, max=vmax)
59
+
60
+ # distance to origin
61
+ d = xyz.norm(dim=-1, keepdim=True)
62
+ xyz = xyz / d.clip(min=1e-8)
63
+ if mode == 'square':
64
+ return xyz * d.square()
65
+
66
+ if mode == 'exp':
67
+ return xyz * torch.expm1(d)
68
+ raise ValueError(f'bad {mode=}')
69
+
70
+ def reg_dense_conf(x, mode):
71
+ """
72
+ extract confidence from prediction head output
73
+ """
74
+ mode, vmin, vmax = mode
75
+ if mode == 'exp':
76
+ return vmin + x.exp().clip(max=vmax-vmin)
77
+ if mode == 'sigmoid':
78
+ return (vmax - vmin) * torch.sigmoid(x) + vmin
79
+ raise ValueError(f'bad {mode=}')
80
+
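
A small sketch of how the mode tuples drive `postprocess`, using the default depth/conf modes that appear later in this diff (models/rayquery.py); the input tensor is a dummy head output:

```python
import torch
from models.heads.postprocess import postprocess

out = torch.randn(1, 4, 16, 16)   # B, C, H, W with C = 3 (xyz) + 1 (confidence)
res = postprocess(out,
                  depth_mode=('exp', -float('inf'), float('inf')),
                  conf_mode=('exp', 1, float('inf')))
print(res['pointmaps'].shape)        # torch.Size([1, 16, 16, 3])
print(res['conf_pointmaps'].shape)   # torch.Size([1, 16, 16])
```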
models/losses.py ADDED
@@ -0,0 +1,257 @@
1
+ bb = breakpoint
2
+ import torch
3
+ import torch.nn as nn
4
+ import copy
5
+ from utils.geometry import normalize_pointcloud
6
+
7
+ class Criterion (nn.Module):
8
+ def __init__(self, criterion=None):
9
+ super().__init__()
10
+ self.criterion = copy.deepcopy(criterion)
11
+
12
+ def get_name(self):
13
+ return f'{type(self).__name__}({self.criterion})'
14
+
15
+ class CrocoLoss (nn.Module):
16
+ def __init__(self,mode='vanilla',eps=1e-4):
17
+ super().__init__()
18
+ self.mode = mode
19
+ def get_name(self):
20
+ return f'CrocoLoss({self.mode})'
21
+
22
+ def forward(self, pred, gt, **kw):
23
+ pred_pts = pred['pointmaps']
24
+ conf = pred['conf']
25
+
26
+ if self.mode == 'vanilla':
27
+ loss = torch.abs(gt-pred_pts)/(torch.exp(conf)) + conf
28
+ elif self.mode == 'bounded_1':
29
+ a=0.25
30
+ b=4.
31
+ conf = (b-a)*torch.sigmoid(conf) + a
32
+ loss = torch.abs(gt-pred_pts)/(conf) + torch.log(conf)
33
+ elif self.mode == 'bounded_2':
34
+ a = 3.0
35
+ b = 3.0
36
+ conf = 2*a * (torch.sigmoid(conf/b)-0.5)
37
+ loss = torch.abs(gt-pred_pts)/torch.exp(conf) + conf
38
+ return loss.mean()
39
+
40
+ class SMDLoss (nn.Module):
41
+ def __init__(self,raw_loss,mode='linear'):
42
+ super().__init__()
43
+ self.mode = mode
44
+ self.raw_loss = raw_loss
45
+ def get_name(self):
46
+ return f'SMDLoss({self.raw_loss},{self.mode})'
47
+
48
+ def forward(self, pred, gt,eps, **kw):
49
+ p_gt = compute_probs(pred,gt,eps=eps)
50
+ # filtering out nan values
51
+ loss = self.raw_loss(p_gt)
52
+ loss_mask = ~torch.isnan(p_gt) & (loss != torch.inf).bool()
53
+ loss = loss[loss_mask]
54
+ return loss.mean()
55
+
56
+ # https://github.com/naver/dust3r/blob/c9e9336a6ba7c1f1873f9295852cea6dffaf770d/dust3r/losses.py#L197
57
+ class ConfLoss (nn.Module):
58
+ """ Weighted regression by learned confidence.
59
+ Assuming the input pixel_loss is a pixel-level regression loss.
60
+
61
+ Principle:
62
+ high-confidence means high conf = 10 ==> conf_loss = x * 10 - alpha*log(10)
+ low confidence means low conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
64
+
65
+ alpha: hyperparameter
66
+ """
67
+
68
+ def __init__(self, raw_loss, alpha=0.2,skip_conf=False):
69
+ super().__init__()
70
+ assert alpha > 0
71
+ self.alpha = alpha
72
+ self.raw_loss = raw_loss
73
+ self.skip_conf = skip_conf
74
+
75
+ def get_name(self):
76
+ return f'ConfLoss({self.raw_loss})'
77
+
78
+ def get_conf_log(self, x):
79
+ return x, torch.log(x)
80
+
81
+ def forward(self, pred, gt,conf, **kw):
82
+ # compute per-pixel loss
83
+ loss = self.raw_loss(gt, pred, **kw)
84
+ # weight by confidence
85
+ if not self.skip_conf:
86
+ conf, log_conf = self.get_conf_log(conf)
87
+ conf_loss = loss * conf - self.alpha * log_conf
88
+ ## average + nan protection (in case of no valid pixels at all)
89
+ conf_loss = conf_loss.mean() if conf_loss.numel() > 0 else 0
90
+ return conf_loss
91
+ else:
92
+ return loss.mean()
93
+
94
+
95
+ class BCELoss(nn.Module):
96
+ def __init__(self):
97
+ super().__init__()
98
+
99
+ def get_name(self):
100
+ return f'BCELoss()'
101
+
102
+ def forward(self, gt, pred):
103
+ # return torch.nn.functional.binary_cross_entropy(pred, gt)
104
+ return torch.nn.functional.binary_cross_entropy_with_logits(pred, gt)
105
+
106
+ class ClassifierLoss(nn.Module):
107
+ def __init__(self,criterion):
108
+ super().__init__()
109
+ self.criterion = criterion
110
+
111
+ def get_name(self):
112
+ return f'ClassifierLoss({self.criterion})'
113
+
114
+ def forward(self, pred, gt):
115
+ return self.criterion(pred, gt)
116
+
117
+ class BaseCriterion(nn.Module):
118
+ def __init__(self, reduction='none'):
119
+ super().__init__()
120
+ self.reduction = reduction
121
+
122
+ class NLLLoss (BaseCriterion):
123
+ """ Negative log likelihood loss """
124
+ def forward(self, pred):
125
+ # assuming the pred is already a log (for stability sake)
126
+ return -pred
127
+ #return -torch.log(pred)
128
+
129
+ class LLoss (BaseCriterion):
130
+ """ L-norm loss
131
+ """
132
+ def forward(self, a, b):
133
+ assert a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3, f'Bad shape = {a.shape}'
134
+ dist = self.distance(a, b)
135
+ assert dist.ndim == a.ndim - 1 # one dimension less
136
+ if self.reduction == 'none':
137
+ return dist
138
+ if self.reduction == 'sum':
139
+ return dist.sum()
140
+ if self.reduction == 'mean':
141
+ return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
142
+ raise ValueError(f'bad {self.reduction=} mode')
143
+
144
+ def distance(self, a, b):
145
+ raise NotImplementedError()
146
+
147
+ class L21Loss (LLoss):
148
+ """ Euclidean distance between 3d points """
149
+
150
+ def distance(self, a, b):
151
+ return torch.norm(a - b, dim=-1)
152
+
153
+ L21 = L21Loss()
154
+
155
+ def apply_log_to_norm(xyz):
156
+ d = xyz.norm(dim=-1, keepdim=True)
157
+ xyz = xyz / d.clip(min=1e-8)
158
+ xyz = xyz * torch.log1p(d)
159
+ return xyz
160
+
161
+ class DepthCompletion (Criterion):
162
+ def __init__(self, criterion, classifier_criterion=None,norm_mode='?None', loss_in_log=False,device='cuda',lambda_classifier=1.0):
163
+ super().__init__(criterion)
164
+ self.criterion.reduction = 'none'
165
+ self.loss_in_log = loss_in_log
166
+ self.device = device
167
+ self.lambda_classifier = lambda_classifier
168
+ self.classifier_criterion = classifier_criterion
169
+
170
+ if norm_mode.startswith('?'):
171
+ # do not normalize points from metric-scale datasets
172
+ self.norm_all = False
173
+ self.norm_mode = norm_mode[1:]
174
+ else:
175
+ self.norm_all = True
176
+ self.norm_mode = norm_mode
177
+
178
+ def forward(self, pred_dict, gt_dict,**kw):
179
+ gt_depths = gt_dict['depths']
180
+ pred_depths = pred_dict['depths']
181
+ gt_masks = gt_dict['valid_masks']
182
+ if gt_masks.sum() == 0:
183
+ return None
184
+ else:
185
+ gt_depths_masked = gt_depths[gt_masks].view(-1,1)
186
+ pred_depths_masked = pred_depths[gt_masks].view(-1,1)
187
+ # this is a loss on the points on the objects
188
+ loss_dict = {'loss_points':self.criterion(pred_depths_masked, gt_depths_masked,pred_dict['conf_pointmaps'][gt_masks])}
189
+ # loss on predicting a mask for the points on the objects
190
+ if 'classifier' in pred_dict and self.classifier_criterion is not None:
191
+ loss_dict['loss_classifier'] = self.classifier_criterion(pred_dict['classifier'], gt_dict['valid_masks'].float(),pred_dict['conf_classifier'])
192
+ loss_dict['loss'] = loss_dict['loss_points'] + self.lambda_classifier * loss_dict['loss_classifier']
193
+ else:
194
+ loss_dict['loss'] = loss_dict['loss_points']
195
+
196
+ return loss_dict
197
+
198
+
199
+ class RayCompletion (Criterion):
200
+ def __init__(self, criterion, classifier_criterion=None,norm_mode='?None', loss_in_log=False,device='cuda',lambda_classifier=1.0):
201
+ super().__init__(criterion)
202
+ self.criterion.reduction = 'none'
203
+ self.loss_in_log = loss_in_log
204
+ self.device = device
205
+ self.lambda_classifier = lambda_classifier
206
+ self.classifier_criterion = classifier_criterion
207
+
208
+ if norm_mode.startswith('?'):
209
+ # do not normalize points from metric-scale datasets
210
+ self.norm_all = False
211
+ self.norm_mode = norm_mode[1:]
212
+ else:
213
+ self.norm_all = True
214
+ self.norm_mode = norm_mode
215
+
216
+ def get_all_pts3d(self, gt_dict, pred_dict):
217
+ gt_pts1 = gt_dict['pointmaps']
218
+ #gt_pts_context = gt_dict['pointmaps_context'][:,0] # we use the first camera given as input for normalization, in our current case that's the only cam
219
+ if 'pointmaps' in pred_dict:
220
+ pr_pts1 = pred_dict['pointmaps']
221
+ else:
222
+ pr_pts1 = None
223
+ mask = gt_dict['valid_masks'].clone()
224
+ # normalize 3d points
225
+ norm_factor = None
226
+
227
+ return gt_pts1, pr_pts1, mask, norm_factor
228
+
229
+ def forward(self, pred_dict, gt_dict, eps=None,**kw):
230
+ gt_pts1, pred_pts1, mask, norm_factor = \
231
+ self.get_all_pts3d(gt_dict, pred_dict, **kw)
232
+ if mask.sum() == 0:
233
+ return None
234
+ else:
235
+ mask_repeated = mask.unsqueeze(-1).repeat(1,1,1,3)
236
+ if norm_factor is not None:
237
+ pred_pts1 = pred_pts1 / norm_factor
238
+ gt_pts1 = gt_pts1 / norm_factor
239
+
240
+ pred_pts1 = pred_pts1[mask_repeated].reshape(-1,3)
241
+ gt_pts1 = gt_pts1[mask_repeated].reshape(-1,3)
242
+
243
+ if self.loss_in_log and self.loss_in_log != 'before':
244
+ # this only make sense when depth_mode == 'exp'
245
+ pred_pts1 = apply_log_to_norm(pred_pts1)
246
+ gt_pts1 = apply_log_to_norm(gt_pts1)
247
+
248
+ # this is a loss on the points on the objects
249
+ loss_dict = {'loss_points':self.criterion(pred_pts1, gt_pts1,pred_dict['conf_pointmaps'][mask])}
250
+ # loss on predicting a mask for the points on the objects
251
+ if 'classifier' in pred_dict and self.classifier_criterion is not None:
252
+ loss_dict['loss_classifier'] = self.classifier_criterion(pred_dict['classifier'], gt_dict['valid_masks'].float(),pred_dict['conf_classifier'])
253
+ loss_dict['loss'] = loss_dict['loss_points'] + self.lambda_classifier * loss_dict['loss_classifier']
254
+ else:
255
+ loss_dict['loss'] = loss_dict['loss_points']
256
+
257
+ return loss_dict
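
A minimal sketch of the confidence weighting in `ConfLoss` on flattened points, using names defined in this file; the tensors are synthetic:

```python
import torch
from models.losses import ConfLoss, L21

crit = ConfLoss(L21, alpha=0.2)
pred = torch.randn(100, 3)                 # predicted 3D points, already masked/flattened
gt = pred + 0.01 * torch.randn(100, 3)     # ground truth close to the prediction
conf = torch.full((100,), 2.0)             # per-point confidence (> 0 so log() is defined)
loss = crit(pred, gt, conf)                # mean of ||err|| * conf - alpha * log(conf)
print(loss.item())
```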
models/pos_embed.py ADDED
@@ -0,0 +1,156 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ # --------------------------------------------------------
4
+ # Position embedding utils
5
+ # --------------------------------------------------------
6
+
7
+
8
+ import numpy as np
9
+
10
+ import torch
11
+
12
+ # --------------------------------------------------------
13
+ # 2D sine-cosine position embedding
14
+ # References:
15
+ # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
16
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
17
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
18
+ # --------------------------------------------------------
19
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
20
+ """
21
+ grid_size: int of the grid height and width
22
+ return:
23
+ pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
24
+ """
25
+ grid_h = np.arange(grid_size, dtype=np.float32)
26
+ grid_w = np.arange(grid_size, dtype=np.float32)
27
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
28
+ grid = np.stack(grid, axis=0)
29
+
30
+ grid = grid.reshape([2, 1, grid_size, grid_size])
31
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
32
+ if n_cls_token>0:
33
+ pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0)
34
+ return pos_embed
35
+
36
+
37
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
38
+ assert embed_dim % 2 == 0
39
+
40
+ # use half of dimensions to encode grid_h
41
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
42
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
43
+
44
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
45
+ return emb
46
+
47
+
48
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
49
+ """
50
+ embed_dim: output dimension for each position
51
+ pos: a list of positions to be encoded: size (M,)
52
+ out: (M, D)
53
+ """
54
+ assert embed_dim % 2 == 0
55
+ omega = np.arange(embed_dim // 2, dtype=float)
56
+ omega /= embed_dim / 2.
57
+ omega = 1. / 10000**omega # (D/2,)
58
+
59
+ pos = pos.reshape(-1) # (M,)
60
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
61
+
62
+ emb_sin = np.sin(out) # (M, D/2)
63
+ emb_cos = np.cos(out) # (M, D/2)
64
+
65
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
66
+ return emb
67
+
68
+
69
+ # --------------------------------------------------------
70
+ # Interpolate position embeddings for high-resolution
71
+ # References:
72
+ # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
73
+ # DeiT: https://github.com/facebookresearch/deit
74
+ # --------------------------------------------------------
75
+ def interpolate_pos_embed(model, checkpoint_model):
76
+ if 'pos_embed' in checkpoint_model:
77
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
78
+ embedding_size = pos_embed_checkpoint.shape[-1]
79
+ num_patches = model.patch_embed.num_patches
80
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
81
+ # height (== width) for the checkpoint position embedding
82
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
83
+ # height (== width) for the new position embedding
84
+ new_size = int(num_patches ** 0.5)
85
+ # class_token and dist_token are kept unchanged
86
+ if orig_size != new_size:
87
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
88
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
89
+ # only the position tokens are interpolated
90
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
91
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
92
+ pos_tokens = torch.nn.functional.interpolate(
93
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
94
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
95
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
96
+ checkpoint_model['pos_embed'] = new_pos_embed
97
+
98
+
99
+ #----------------------------------------------------------
100
+ # RoPE2D: RoPE implementation in 2D
101
+ #----------------------------------------------------------
102
+
103
+ try:
104
+ from extensions.curope import cuRoPE2D
105
+ RoPE2D = cuRoPE2D
106
+ except ImportError:
107
+ print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead')
108
+
109
+ class RoPE2D(torch.nn.Module):
110
+
111
+ def __init__(self, freq=100.0, F0=1.0):
112
+ super().__init__()
113
+ self.base = freq
114
+ self.F0 = F0
115
+ self.cache = {}
116
+
117
+ def get_cos_sin(self, D, seq_len, device, dtype):
118
+ if (D,seq_len,device,dtype) not in self.cache:
119
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
120
+ t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
121
+ freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
122
+ freqs = torch.cat((freqs, freqs), dim=-1)
123
+ cos = freqs.cos() # (Seq, Dim)
124
+ sin = freqs.sin()
125
+ self.cache[D,seq_len,device,dtype] = (cos,sin)
126
+ return self.cache[D,seq_len,device,dtype]
127
+
128
+ @staticmethod
129
+ def rotate_half(x):
130
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
131
+ return torch.cat((-x2, x1), dim=-1)
132
+
133
+ def apply_rope1d(self, tokens, pos1d, cos, sin):
134
+ assert pos1d.ndim==2
135
+ cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
136
+ sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
137
+ return (tokens * cos) + (self.rotate_half(tokens) * sin)
138
+
139
+ def forward(self, tokens, positions):
140
+ """
141
+ input:
142
+ * tokens: batch_size x nheads x ntokens x dim
143
+ * positions: batch_size x ntokens x 2 (y and x position of each token)
144
+ output:
145
+ * tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
146
+ """
147
+ assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
148
+ D = tokens.size(3) // 2
149
+ assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2
150
+ cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype)
151
+ # split features into two along the feature dimension, and apply rope1d on each half
152
+ y, x = tokens.chunk(2, dim=-1)
153
+ y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
154
+ x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
155
+ tokens = torch.cat((y, x), dim=-1)
156
+ return tokens
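
A usage sketch for the RoPE2D module above, assuming the pure-PyTorch fallback (no compiled curope extension) and CPU tensors:

```python
import torch
from models.pos_embed import RoPE2D

rope = RoPE2D(freq=100.0)
B, heads, H, W, dim = 2, 4, 8, 8, 64                   # dim must be even
tokens = torch.randn(B, heads, H * W, dim)
ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
positions = torch.stack([ys, xs], dim=-1).reshape(1, H * W, 2).repeat(B, 1, 1)
out = rope(tokens, positions)                          # same shape as tokens
print(out.shape)                                       # torch.Size([2, 4, 64, 64])
```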
models/rayquery.py ADDED
@@ -0,0 +1,227 @@
1
+ bb = breakpoint
2
+ import torch
3
+ import torch.nn as nn
4
+ from models.blocks import DecoderBlock, Block, PatchEmbed, PositionGetter
5
+ from models.pos_embed import get_2d_sincos_pos_embed, RoPE2D
6
+ from models.losses import *
7
+ from utils.geometry import center_pointmaps, compute_rays
8
+ from models.heads import head_factory
9
+
10
+ def init_weights(m):
11
+ if isinstance(m, nn.Linear):
12
+ # we use xavier_uniform following official JAX ViT:
13
+ torch.nn.init.xavier_uniform_(m.weight)
14
+ if isinstance(m, nn.Linear) and m.bias is not None:
15
+ nn.init.constant_(m.bias, 0)
16
+ elif isinstance(m, nn.LayerNorm):
17
+ if m.bias is not None:
18
+ nn.init.constant_(m.bias, 0)
19
+ if m.weight is not None:
20
+ nn.init.constant_(m.weight, 1.0)
21
+ elif isinstance(m, nn.Parameter):
22
+ nn.init.normal_(m, std=0.02)
23
+
24
+ class RayEncoder(nn.Module):
25
+ def __init__(self,
26
+ dim=256,
27
+ patch_size=8,
28
+ img_size=(128,128),
29
+ depth=3,
30
+ num_heads=4,
31
+ pos_embed='RoPE100',
32
+ ):
33
+ super().__init__()
34
+ self.img_size = img_size
35
+ self.patch_embed = PatchEmbed(img_size=self.img_size, patch_size=patch_size, in_chans=2, embed_dim=dim)
36
+ self.dim = dim
37
+ if pos_embed.startswith('RoPE'):
38
+ freq = float(pos_embed[len('RoPE'):])
39
+ self.rope = RoPE2D(freq=freq)
40
+ else:
41
+ self.rope = None
42
+ self.blocks = nn.ModuleList([Block(dim=dim, num_heads=num_heads,rope=self.rope) for _ in range(depth)])
43
+ self.initialize_weights()
44
+
45
+ def initialize_weights(self):
46
+ # patch embed
47
+ self.patch_embed._init_weights()
48
+
49
+ # linears and layer norms
50
+ self.apply(init_weights)
51
+
52
+ def forward(self, rays):
53
+ rays = rays.permute(0,3,1,2)
54
+ rays, pos = self.patch_embed(rays)
55
+ for blk in self.blocks:
56
+ rays = blk(rays, pos)
57
+ return rays, pos
58
+
59
+ class PointmapEncoder(nn.Module):
60
+ def __init__(self,
61
+ dim=256,
62
+ patch_size=8,
63
+ img_size=(128,128),
64
+ depth=3,
65
+ num_heads=4,
66
+ pos_embed='RoPE100',
67
+ ):
68
+ super().__init__()
69
+ self.img_size = img_size
70
+ self.patch_embed = PatchEmbed(img_size=self.img_size, patch_size=patch_size, in_chans=3, embed_dim=dim)
71
+ self.dim = dim
72
+ self.patch_size = patch_size
73
+
74
+ if pos_embed.startswith('RoPE'):
75
+ freq = float(pos_embed[len('RoPE'):])
76
+ self.rope = RoPE2D(freq=freq)
77
+ else:
78
+ self.rope = None
79
+ self.blocks = nn.ModuleList([Block(dim=dim, num_heads=num_heads,rope=self.rope) for _ in range(depth)])
80
+ self.masked_token = nn.Parameter(torch.randn(1,1,3))
81
+ self.initialize_weights()
82
+
83
+ def initialize_weights(self):
84
+ # patch embed
85
+ self.patch_embed._init_weights()
86
+
87
+ # linears and layer norms
88
+ self.apply(init_weights)
89
+
90
+ def forward(self, pointmaps,masks=None):
91
+ # replace masked points (not on object) with a learned token
92
+ pointmaps[~masks] = self.masked_token.to(pointmaps.dtype).to(pointmaps.device)
93
+ pointmaps = pointmaps.permute(0,3,1,2)
94
+ pointmaps, pos = self.patch_embed(pointmaps)
95
+
96
+ for blk in self.blocks:
97
+ pointmaps = blk(pointmaps, pos)
98
+ return pointmaps, pos
99
+
100
+ class RayQuery(nn.Module):
101
+ def __init__(self,
102
+ ray_enc=RayEncoder(),
103
+ pointmap_enc=PointmapEncoder(),
104
+ dec_pos_embed='RoPE100',
105
+ decoder_dim=256,
106
+ decoder_depth=3,
107
+ decoder_num_heads=4,
108
+ imshape=(128,128),
109
+ pts_head_type='dpt',
110
+ classifier_head_type='dpt_mask',
111
+ criterion=ConfLoss(L21),
112
+ return_all_blocks=True,
113
+ depth_mode=('exp',-float('inf'),float('inf')),
114
+ conf_mode=('exp',1,float('inf')),
115
+ classifier_mode=('raw',0,1),
116
+ dino_layers=[23],
117
+ ):
118
+ super().__init__()
119
+ self.ray_enc = ray_enc
120
+ self.pointmap_enc = pointmap_enc
121
+ self.dec_depth = decoder_depth
122
+ self.dec_embed_dim = decoder_dim
123
+ self.enc_embed_dim = ray_enc.dim
124
+ self.patch_size = pointmap_enc.patch_size
125
+ self.depth_mode = depth_mode
126
+ self.conf_mode = conf_mode
127
+ self.classifier_mode = classifier_mode
128
+ self.skip_dino = len(dino_layers) == 0
129
+ self.pts_head_type = pts_head_type
130
+ self.classifier_head_type = classifier_head_type
131
+
132
+ if dec_pos_embed.startswith('RoPE'):
133
+ self.dec_pos_embed = RoPE2D(freq=100.0)
134
+ else:
135
+ raise NotImplementedError(f'{dec_pos_embed} not implemented')
136
+ self.decoder_blocks = nn.ModuleList([DecoderBlock(dim=decoder_dim, num_heads=decoder_num_heads,
137
+ rope=self.dec_pos_embed) for _ in range(decoder_depth)])
138
+ self.pts_head = head_factory(pts_head_type, 'pts3d', self, has_conf=True)
139
+
140
+ self.classifier_head = head_factory(classifier_head_type, 'pts3d', self, has_conf=True)
141
+ self.imshape = imshape
142
+ self.criterion = criterion
143
+ self.return_all_blocks = return_all_blocks
144
+
145
+ # dino projection
146
+ self.dino_layers = dino_layers
147
+ self.dino_proj = nn.Linear(1024 * len(dino_layers), decoder_dim)
148
+ self.dino_pos_getter = PositionGetter()
149
+
150
+ self.initialize_weights()
151
+
152
+ def initialize_weights(self):
153
+ self.apply(init_weights)
154
+
155
+ def forward_encoders(self, rays, pointmaps,masks=None):
156
+ # encode rays
157
+ rays, rays_pos = self.ray_enc(rays)
158
+
159
+ # encode pointmaps
160
+ B, H, W, C = pointmaps.shape
161
+ pointmaps = pointmaps.reshape(B,H,W,C) # each pointmap is encoded separately
162
+ pointmaps, pointmaps_pos = self.pointmap_enc(pointmaps,masks=masks)
163
+ new_shape = pointmaps.shape
164
+ pointmaps = pointmaps.reshape(new_shape[0],*new_shape[1:])
165
+ pointmaps_pos = pointmaps_pos[:B]
166
+
167
+ return rays, rays_pos, pointmaps, pointmaps_pos
168
+
169
+ def forward_decoder(self, rays, rays_pos, pointmaps, pointmaps_pos):
170
+ if self.return_all_blocks:
171
+ all_blocks = []
172
+ for blk in self.decoder_blocks:
173
+ rays, pointmaps = blk(rays, pointmaps, rays_pos, pointmaps_pos)
174
+ all_blocks.append(rays)
175
+ return all_blocks
176
+ else:
177
+ for blk in self.decoder_blocks:
178
+ rays, pointmaps = blk(rays, pointmaps, rays_pos, pointmaps_pos)
179
+ return rays
180
+
181
+ def get_dino_pos(self,dino_features):
182
+ # dino runs on 14x14 patches
183
+ # note: assuming we cropped or resized down!
184
+ dino_H = self.imshape[0]//14
185
+ dino_W = self.imshape[1]//14
186
+ dino_pos = self.dino_pos_getter(dino_features.shape[0],dino_H,dino_W,dino_features.device)
187
+ return dino_pos
188
+
189
+ def forward(self,batch,mode='loss'):
190
+ # prep for encoders
191
+ rays = compute_rays(batch) # we are querying the first camera
192
+ pointmaps_context = batch['input_cams']['pointmaps'] # we are using the other cameras as context
193
+ input_masks = batch['input_cams']['valid_masks']
194
+
195
+ # run the encoders
196
+ rays, rays_pos, pointmaps, pointmaps_pos = self.forward_encoders(rays, pointmaps_context,masks=input_masks)
197
+ ## adding dino features
198
+ if not self.skip_dino:
199
+ dino_features = batch['input_cams']['dino_features']
200
+ dino_features = self.dino_proj(dino_features)
201
+ if len(dino_features.shape) == 4:
202
+ dino_features = dino_features.squeeze(1)
203
+ dino_pos = self.get_dino_pos(dino_features)
204
+ pointmaps = torch.cat([pointmaps,dino_features],dim=1)
205
+ pointmaps_pos = torch.cat([pointmaps_pos,dino_pos],dim=1)
206
+ else:
207
+ dino_features = None
208
+ dino_pos = None
209
+ # decoder
210
+ rays = self.forward_decoder(rays, rays_pos, pointmaps, pointmaps_pos)
211
+ pts_pred_dict = self.pts_head(rays, self.imshape)
212
+ classifier_pred_dict = self.classifier_head(rays, self.imshape)
213
+
214
+ pred_dict = {**pts_pred_dict,**classifier_pred_dict}
215
+ gt_dict = batch['new_cams']
216
+ loss_dict = self.criterion(pred_dict, gt_dict)
217
+
218
+ del rays, rays_pos, pointmaps, pointmaps_pos, dino_features, dino_pos, pointmaps_context, input_masks, pts_pred_dict, classifier_pred_dict
219
+
220
+ if mode == 'loss':
221
+ # delete all the variables that are not needed
222
+ del pred_dict, gt_dict
223
+ return loss_dict
224
+ elif mode == 'viz':
225
+ return pred_dict, gt_dict, loss_dict
226
+ else:
227
+ raise ValueError(f"Invalid mode: {mode}")
readme.md ADDED
@@ -0,0 +1,112 @@
1
+ <div align="center">
2
+
3
+ # RaySt3R: Predicting Novel Depth Maps for Zero-Shot Object Completion
4
+
5
+ <a href="https://arxiv.org/abs/2506.05285"><img src='https://img.shields.io/badge/arXiv-Paper-red?logo=arxiv&logoColor=white' alt='arXiv'></a>
6
+ <a href='https://rayst3r.github.io'><img src='https://img.shields.io/badge/Project_Page-Website-green?logo=googlechrome&logoColor=white' alt='Project Page'></a>
7
+
8
+ </div>
9
+
10
+ <div align="center">
11
+ <img src="assets/overview.png" width="80%" alt="Method overview">
12
+ </div>
13
+
14
+ ## 📚 Citation
15
+ ```bibtex
16
+ @misc{rayst3r,
17
+ title={RaySt3R: Predicting Novel Depth Maps for Zero-Shot Object Completion},
18
+ author={Bardienus P. Duisterhof and Jan Oberst and Bowen Wen and Stan Birchfield and Deva Ramanan and Jeffrey Ichnowski},
19
+ year={2025},
20
+ eprint={2506.05285},
21
+ archivePrefix={arXiv},
22
+ primaryClass={cs.CV},
23
+ url={https://arxiv.org/abs/2506.05285},
24
+ }
25
+ ```
26
+ ## ✅ TO-DOs
27
+
28
+ - [x] Inference code
29
+ - [x] Local gradio demo
30
+ - [ ] Huggingface demo
31
+ - [ ] Docker
32
+ - [ ] Training code
33
+ - [ ] Eval code
34
+ - [ ] ViT-S, No-DINO and Pointmap models
35
+ - [ ] Dataset release
36
+
37
+ # ⚙️ Installation
38
+
39
+ ```bash
40
+ mamba create -n rayst3r python=3.11 cmake=3.14.0
41
+ mamba activate rayst3r
42
+ mamba install pytorch torchvision pytorch-cuda=12.4 -c pytorch -c nvidia # change to your version of cuda
43
+ pip install -r requirements.txt
44
+
45
+ # compile the cuda kernels for RoPE
46
+ cd extensions/curope/
47
+ python setup.py build_ext --inplace
48
+ cd ../../
49
+ ```
50
+
51
+ # 🚀 Usage
52
+
53
+ The expected input for RaySt3R is a folder with the following structure:
54
+
55
+ <pre><code>
56
+ 📁 data_dir/
57
+ ├── cam2world.pt # Camera-to-world transformation (PyTorch tensor), 4x4 - eye(4) if not provided
58
+ ├── depth.png # Depth image, uint16 with max 10 meters
59
+ ├── intrinsics.pt # Camera intrinsics (PyTorch tensor), 3x3
60
+ ├── mask.png # Binary mask image
61
+ └── rgb.png # RGB image
62
+ </code></pre>
63
+
64
+ Note that the depth image must be saved as uint16, normalized to a 0-10 meter range. We provide an example directory in `example_scene`.
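
For illustration only (not part of the repo), a sketch of how such a `data_dir` could be written from Python; the resolution and intrinsics below are placeholder values, and `rgb.png` / `mask.png` are saved as ordinary 8-bit images:

```python
import numpy as np
import torch
from PIL import Image

depth_m = np.random.rand(480, 640).astype(np.float32) * 2.0           # depth in meters
depth_u16 = (np.clip(depth_m / 10.0, 0, 1) * 65535).astype(np.uint16)  # 0-10 m -> uint16
Image.fromarray(depth_u16).save("data_dir/depth.png")

torch.save(torch.eye(4), "data_dir/cam2world.pt")                      # identity pose
K = torch.tensor([[600., 0., 320.], [0., 600., 240.], [0., 0., 1.]])   # placeholder intrinsics
torch.save(K, "data_dir/intrinsics.pt")
```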
65
+ Run RaySt3R with:
66
+
67
+
68
+ ```bash
69
+ python3 eval_wrapper/eval.py example_scene/
70
+ ```
71
+ This writes a colored point cloud back into the input directory.
72
+
73
+ Optional flags:
74
+ ```bash
75
+ --visualize # Spins up a rerun client to visualize predictions and camera poses
76
+ --run_octmae # Novel views sampled with the OctMAE parameters (see paper)
77
+ --set_conf N # Sets confidence threshold to N
78
+ --n_pred_views # Number of predicted views along each axis of the grid (e.g. 5 gives 22 views in total)
79
+ --filter_all_masks # Use all masks, point gets rejected if in background for a single mask
80
+ --tsdf # Fits TSDF to depth maps
81
+ ```
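
For example, to combine several of the flags above (the numeric values are arbitrary, illustrative choices):

```bash
python3 eval_wrapper/eval.py example_scene/ --set_conf 5.0 --n_pred_views 3 --tsdf
```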
82
+
83
+ # 🧪 Gradio app
84
+
85
+ We also provide a gradio app, which uses <a href="https://wangrc.site/MoGePage/">MoGe</a> and <a href="https://github.com/danielgatis/rembg">Rembg</a> to generate 3D from a single image.
86
+
87
+ Launch it with:
88
+ ```bash
89
+ python app.py
90
+ ```
91
+
92
+ # 🎛️ Parameter Guide
93
+
94
+ Certain applications may benefit from different hyperparameters; here we provide guidance on how to select them.
95
+
96
+ #### 🔁 View Sampling
97
+
98
+ We sample novel views evenly on a cylindrical equal-area projection of the sphere.
99
+ Customize sampling in <a href="eval_wrapper/sample_poses.py">sample_poses.py</a>. Use --n_pred_views to reduce the total number of views, making inference faster and reducing overlap and artifacts.
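
As a rough sketch of the cylindrical equal-area idea (uniform in height and azimuth); the actual grid and pose construction live in `sample_poses.py` and may differ in the details:

```python
import numpy as np

def sample_viewpoints(n, radius=1.0):
    """n x n camera centers on a sphere: equal-area in z, uniform in azimuth."""
    z = np.linspace(-1 + 1.0 / n, 1 - 1.0 / n, n)        # equal-area height bands
    theta = np.linspace(0, 2 * np.pi, n, endpoint=False)  # azimuth samples
    zz, tt = np.meshgrid(z, theta, indexing="ij")
    r = np.sqrt(1.0 - zz ** 2)
    pts = np.stack([r * np.cos(tt), r * np.sin(tt), zz], axis=-1) * radius
    return pts.reshape(-1, 3)
```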
100
+
101
+ #### 🟢 Confidence Threshold
102
+
103
+ You can set the confidence threshold with the --set_conf flag. As shown in the paper, a higher threshold generally improves accuracy and reduces edge bleeding, but at the cost of completeness.
104
+
105
+ #### 🧼 RaySt3R Masks
106
+
107
+ On top of what was presented in the paper, we also provide the option to consider all predicted masks for each point. I.e., if any of the predicted masks classifies a point as background, that point gets removed.
108
+ In our limited testing this led to cleaner predictions, but it occasionally carves out crucial parts of the geometry.
109
+
110
+ # 🏋️ Training
111
+
112
+ The RaySt3R training command is provided in <a href="xps/train_rayst3r.py">train_rayst3r.py</a>; documentation will follow later.
requirements.txt ADDED
@@ -0,0 +1,18 @@
1
+ matplotlib
2
+ numpy
3
+ open3d
4
+ Pillow
5
+ pyrender
6
+ rerun
7
+ setuptools
8
+ tqdm
9
+ trimesh
10
+ huggingface-hub
11
+ wandb
12
+ einops
13
+
14
+ # for app.py
15
+ onnxruntime
16
+ gradio
17
+ rembg
18
+ git+https://github.com/microsoft/MoGe.git
utils/augmentations.py ADDED
@@ -0,0 +1,184 @@
1
+ import random
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from abc import ABC, abstractmethod
5
+ from torchvision.transforms import GaussianBlur
6
+ from utils.batch_prep import compute_pointmaps
7
+ import imgaug as ia
8
+ import imgaug.augmenters as iaa
9
+ import numpy as np
10
+
11
+ class ChangeBright(torch.nn.Module):
12
+ def __init__(self,prob=0.5,mag=[0.5,2.0]):
13
+ super().__init__()
14
+ self.mag = mag
15
+ self.prob = prob
16
+
17
+ def forward(self,rgb):
18
+ #if np.random.uniform()>=self.prob:
19
+ #return rgb
20
+ n = rgb.shape[0]
21
+ apply_aug = np.random.uniform(0,1,size=n) < self.prob
22
+ aug = iaa.MultiplyBrightness(np.random.uniform(self.mag[0],self.mag[1])) # NOTE: iaa's deterministic mode is buggy, so we sample the magnitude ourselves
23
+ rgb[apply_aug] = aug(images=rgb[apply_aug])
24
+ return rgb
25
+
26
+ class ChangeContrast(torch.nn.Module):
27
+ def __init__(self,prob=0.5,mag=[0.5,2.0]):
+ super().__init__()
+ self.mag = mag
29
+ self.prob = prob
30
+
31
+ def __call__(self,rgb):
32
+ n = rgb.shape[0]
33
+ apply_aug = np.random.uniform(0,1,size=n) < self.prob
34
+
35
+ aug = iaa.GammaContrast(np.random.uniform(self.mag[0],self.mag[1]))
36
+ rgb[apply_aug] = aug(images=rgb[apply_aug])
37
+ return rgb
38
+
39
+ class SaltAndPepper:
40
+ def __init__(self, prob=0.3, ratio=0.1, per_channel=True):
41
+ self.prob = prob
42
+ self.ratio = ratio
43
+ self.per_channel = per_channel
44
+
45
+ def __call__(self, rgb):
46
+ n = rgb.shape[0]
47
+ apply_aug = np.random.uniform(0,1,size=n) < self.prob
48
+ aug = iaa.SaltAndPepper(self.ratio, per_channel=self.per_channel).to_deterministic()
49
+ rgb[apply_aug] = aug(images=rgb[apply_aug])
50
+ return rgb
51
+
52
+ class RGBGaussianNoise:
53
+ def __init__(self, max_noise=10, prob=0.5):
54
+ self.max_noise = max_noise
55
+ self.prob = prob
56
+
57
+ def __call__(self, rgb):
58
+ n = rgb.shape[0]
59
+ apply_aug = np.random.uniform(0,1,size=n) < self.prob
60
+
61
+ shape = rgb.shape
62
+ noise = np.random.normal(0, self.max_noise, size=shape).clip(-self.max_noise, self.max_noise)
63
+ rgb[apply_aug] = (rgb[apply_aug].astype(float) + noise[apply_aug]).clip(0,255).astype(np.uint8)
64
+ return rgb
65
+
66
+ # from https://github.com/mihdalal/manipgen/blob/master/manipgen/utils/obs_utils.py
67
+ class DepthWarping(torch.nn.Module):
68
+ def __init__(self, std=0.5, prob=0.8):
69
+ super().__init__()
70
+ self.std = std
71
+ self.prob = prob
72
+
73
+ def forward(self, depths, device=None):
74
+ if device is None:
75
+ device = depths.device
76
+
77
+ n, _, h, w = depths.shape
78
+
79
+ # Generate Gaussian shifts
80
+ gaussian_shifts = torch.normal(mean=0, std=self.std, size=(n, h, w, 2), device=device).float()
81
+ apply_shifts = torch.rand(n, device=device) < self.prob
82
+ gaussian_shifts[~apply_shifts] = 0.0
83
+
84
+ # Create grid for the original coordinates
85
+ xx = torch.linspace(0, w - 1, w, device=device)
86
+ yy = torch.linspace(0, h - 1, h, device=device)
87
+ xx = xx.unsqueeze(0).repeat(h, 1)
88
+ yy = yy.unsqueeze(1).repeat(1, w)
89
+ grid = torch.stack((xx, yy), 2).unsqueeze(0) # Add batch dimension
90
+
91
+ # Apply Gaussian shifts to the grid
92
+ grid = grid + gaussian_shifts
93
+
94
+ # Normalize grid values to the range [-1, 1] for grid_sample
95
+ grid[..., 0] = (grid[..., 0] / (w - 1)) * 2 - 1
96
+ grid[..., 1] = (grid[..., 1] / (h - 1)) * 2 - 1
97
+
98
+ # Perform the remapping using grid_sample
99
+ depth_interp = F.grid_sample(depths, grid, mode='bilinear', padding_mode='border', align_corners=True)
100
+
101
+ # Remove the batch and channel dimensions
102
+ depth_interp = depth_interp.squeeze(0).squeeze(0)
103
+
104
+ return depth_interp
105
+
106
+ class DepthHoles(torch.nn.Module):
107
+ def __init__(self, prob=0.5, kernel_size_lower=3, kernel_size_upper=27, sigma_lower=1.0,
108
+ sigma_upper=7.0, thresh_lower=0.6, thresh_upper=0.9):
109
+ super().__init__()
110
+ self.prob = prob
111
+ self.kernel_size_lower = kernel_size_lower
112
+ self.kernel_size_upper = kernel_size_upper
113
+ self.sigma_lower = sigma_lower
114
+ self.sigma_upper = sigma_upper
115
+ self.thresh_lower = thresh_lower
116
+ self.thresh_upper = thresh_upper
117
+
118
+ def forward(self, depths, device=None):
119
+ if device is None:
120
+ device = depths.device
121
+
122
+ n, _, h, w = depths.shape
123
+ # generate random noise
124
+ noise = torch.rand(n, 1, h, w, device=device)
125
+
126
+ # apply gaussian blur
127
+ k = random.choice(list(range(self.kernel_size_lower, self.kernel_size_upper+1, 2)))
128
+ noise = GaussianBlur(kernel_size=k, sigma=(self.sigma_lower, self.sigma_upper))(noise)
129
+
130
+ # normalize noise
131
+ noise = (noise - noise.min()) / (noise.max() - noise.min())
132
+
133
+ # apply thresholding
134
+ thresh = torch.rand(n, 1, 1, 1, device=device) * (self.thresh_upper - self.thresh_lower) + self.thresh_lower
135
+ mask = (noise > thresh)
136
+ prob = self.prob
137
+ keep_mask = torch.rand(n, device=device) < prob
138
+ mask[~keep_mask, :] = 0
139
+
140
+ return mask
141
+
142
+ class DepthNoise(torch.nn.Module):
143
+ def __init__(self, std=0.005,prob=1.0):
144
+ super().__init__()
145
+ self.std = std
146
+ self.prob = prob
147
+
148
+ def forward(self, depths, device=None):
149
+ if device is None:
150
+ device = depths.device
151
+
152
+ n, _, h, w = depths.shape
153
+ apply_noise = torch.rand(n, device=device) < self.prob
154
+ noise = torch.randn(n, 1, h, w, device=device) * self.std
155
+ noise[~apply_noise] = 0.0
156
+ return depths + noise
157
+
158
+ class Augmentor(torch.nn.Module):
159
+ def __init__(self, depth_holes=DepthHoles(), depth_warping=DepthWarping(),depth_noise=DepthNoise(),
160
+ rgb_operators=[ChangeBright(),SaltAndPepper(),ChangeContrast(),RGBGaussianNoise()]):
161
+ super().__init__()
162
+ self.depth_holes = depth_holes
163
+ self.depth_warping = depth_warping
164
+ self.depth_noise = depth_noise
165
+ self.rgb_operators = rgb_operators
166
+
167
+ def forward(self, batch):
168
+ input_depths = batch['input_cams']['depths']
169
+ if self.depth_holes.prob > 0:
170
+ masks = self.depth_holes(input_depths)
171
+ batch['input_cams']['valid_masks'][masks] = False
172
+ #if self.depth_warping.prob > 0:
173
+ #input_depths = self.depth_warping(input_depths)
174
+ if self.depth_noise.prob > 0:
175
+ input_depths = self.depth_noise(input_depths)
176
+
177
+ input_rgbs = batch['input_cams']['imgs'].squeeze(1).cpu().numpy() # this is a bit inefficient, but it's ok..
178
+ for op in self.rgb_operators:
179
+ input_rgbs = op(input_rgbs)
180
+ batch['input_cams']['imgs'] = torch.from_numpy(input_rgbs).cuda().unsqueeze(1)
181
+
182
+ batch['input_cams']['depths'] = input_depths
183
+ batch['input_cams']['pointmaps'] = compute_pointmaps(batch['input_cams']['depths'],batch['input_cams']['Ks'],batch['input_cams']['c2ws']) # now we're doing this twice, but alas
184
+ return batch
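
A quick sketch of the individual depth augmentations above on synthetic data (the full Augmentor expects a batch dict and CUDA tensors):

```python
import torch
from utils.augmentations import DepthNoise, DepthHoles

depths = torch.rand(2, 1, 64, 64) + 0.5              # (n, 1, h, w), meters
noisy = DepthNoise(std=0.005, prob=1.0)(depths)      # additive Gaussian noise
holes = DepthHoles(prob=1.0)(depths)                 # boolean mask of regions to drop
print(noisy.shape, holes.shape, holes.dtype)
```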
utils/batch_prep.py ADDED
@@ -0,0 +1,141 @@
1
+ import torch
2
+ import torchvision.transforms as tvf
3
+
4
+ dino_patch_size = 14
5
+
6
+ def batch_to_device(batch,device='cuda'):
7
+ for key in batch:
8
+ if isinstance(batch[key],torch.Tensor):
9
+ batch[key] = batch[key].to(device)
10
+ elif isinstance(batch[key],dict):
11
+ batch[key] = batch_to_device(batch[key],device)
12
+ return batch
13
+
14
+
15
+ def compute_pointmap(depth: torch.Tensor, intrinsics: torch.Tensor, cam2world: torch.Tensor = None) -> torch.Tensor:
16
+ fx, fy = intrinsics[0, 0], intrinsics[1, 1]
17
+ cx, cy = intrinsics[0, 2], intrinsics[1, 2]
18
+ h, w = depth.shape
19
+
20
+ i, j = torch.meshgrid(torch.arange(w), torch.arange(h), indexing='xy')
21
+ i = i.to(depth.device)
22
+ j = j.to(depth.device)
23
+
24
+ x_cam = (i - cx) * depth / fx
25
+ y_cam = (j - cy) * depth / fy
26
+
27
+ points_cam = torch.stack([x_cam, y_cam, depth], axis=-1)
28
+
29
+ if cam2world is not None:
30
+ points_cam = torch.matmul(cam2world[:3, :3], points_cam.reshape(-1, 3).T).T + cam2world[:3, 3]
31
+ points_cam = points_cam.reshape(h, w, 3)
32
+
33
+ return points_cam
34
+
35
+ def compute_pointmaps(depths: torch.Tensor, intrinsics: torch.Tensor, cam2worlds: torch.Tensor) -> torch.Tensor:
36
+ pointmaps = []
37
+ depth_shape = depths.shape
38
+ pointmaps_shape = depths.shape + (3,)
39
+ for depth, K, c2w in zip(depths, intrinsics, cam2worlds):
40
+ n_views = depth.shape[0]
41
+ for i in range(n_views):
42
+ pointmaps.append(compute_pointmap(depth[i], K[i],c2w[i]))
43
+ return torch.stack(pointmaps).reshape(pointmaps_shape)
44
+
45
+ def depth_to_metric(depth):
46
+ # depth: shape H x W
47
+ # we want to convert the depth to a metric depth
48
+ depth_max = 10.0
49
+ depth_scaled = depth_max * (depth / 65535.0)
50
+
51
+ return depth_scaled
52
+
53
+ def make_rgb_transform() -> tvf.Compose:
54
+ return tvf.Compose([
55
+ #tvf.ToTensor(),
56
+ #lambda x: 255.0 * x[:3], # Discard alpha component and scale by 255
57
+ tvf.Normalize(
58
+ mean=(123.675, 116.28, 103.53),
59
+ std=(58.395, 57.12, 57.375),
60
+ ),
61
+ ])
62
+
63
+ rgb_transform = make_rgb_transform()
64
+
65
+ def compute_dino_and_store_features(dino_model : torch.nn.Module, rgb: torch.Tensor, mask: torch.Tensor,dino_layers: list[int] = None) -> torch.Tensor:
66
+ """Computes the DINO features given an RGB image."""
67
+ rgb = rgb.squeeze(1)
68
+ mask = mask.squeeze(1)
69
+ rgb = rgb.permute(0,3,1,2)
70
+ mask = mask.unsqueeze(1).repeat(1,3,1,1)
71
+ rgb = rgb * mask
72
+
73
+ rgb = rgb.float()
74
+ H, W = rgb.shape[-2:]
75
+ goal_H, goal_W = H//dino_patch_size*dino_patch_size, W//dino_patch_size*dino_patch_size
76
+ resize_transform = tvf.CenterCrop([goal_H, goal_W])
77
+ with torch.no_grad():
78
+ rgb = resize_transform(rgb)
79
+ rgb = rgb_transform(rgb)
80
+ all_feat = dino_model.get_intermediate_layers(rgb, dino_layers)
81
+ dino_feat = torch.cat(all_feat, dim=-1)
82
+ return dino_feat
83
+
84
+
85
+ def prepare_fast_batch(batch,dino_model = None,dino_layers = None):
86
+ # depth to metric
87
+ batch['new_cams']['depths'] = depth_to_metric(batch['new_cams']['depths'])
88
+ batch['input_cams']['depths'] = depth_to_metric(batch['input_cams']['depths'])
89
+
90
+ # compute pointmaps
91
+ batch['new_cams']['pointmaps'] = compute_pointmaps(batch['new_cams']['depths'],batch['new_cams']['Ks'],batch['new_cams']['c2ws'])
92
+ batch['input_cams']['pointmaps'] = compute_pointmaps(batch['input_cams']['depths'],batch['input_cams']['Ks'],batch['input_cams']['c2ws'])
93
+
94
+ # compute dino features
95
+ if dino_model is not None and len(dino_layers) > 0:
96
+ batch['input_cams']['dino_features'] = compute_dino_and_store_features(dino_model,batch['input_cams']['imgs'],batch['input_cams']['valid_masks'],dino_layers)
97
+
98
+ return batch
99
+
100
+
101
+ def normalize_batch(batch,normalize_mode):
102
+ scale_factors = []
103
+ if normalize_mode == 'None':
104
+ pass
105
+ elif normalize_mode == 'median':
106
+ B = batch['input_cams']['valid_masks'].shape[0]
107
+ for b in range(B):
108
+ input_mask = batch['input_cams']['valid_masks'][b]
109
+ depth_median = batch['input_cams']['depths'][b][input_mask].median()
110
+ scale_factor = 1.0 / depth_median
111
+ scale_factors.append(scale_factor)
112
+ batch['input_cams']['depths'][b] = scale_factor * batch['input_cams']['depths'][b]
113
+ batch['input_cams']['pointmaps'][b] = scale_factor * batch['input_cams']['pointmaps'][b]
114
+ batch['input_cams']['c2ws'][b][0,:3,-1] = scale_factor * batch['input_cams']['c2ws'][b][0,:3,-1]
115
+
116
+ batch['new_cams']['depths'][b] = scale_factor * batch['new_cams']['depths'][b]
117
+ batch['new_cams']['pointmaps'][b] = scale_factor * batch['new_cams']['pointmaps'][b]
118
+ batch['new_cams']['c2ws'][b][:,:3,-1] = scale_factor * batch['new_cams']['c2ws'][b][:,:3,-1]
119
+
120
+ return batch, scale_factors
121
+
122
+ def denormalize_batch(batch,pred,gt,scale_factors):
123
+ B = len(scale_factors)
124
+ n_new_cams = batch['new_cams']['c2ws'].shape[1]
125
+ for b in range(B):
126
+ new_scale_factor = 1.0 / scale_factors[b]
127
+ batch['input_cams']['depths'][b] = new_scale_factor * batch['input_cams']['depths'][b]
128
+ batch['input_cams']['pointmaps'][b] = new_scale_factor * batch['input_cams']['pointmaps'][b]
129
+ batch['input_cams']['c2ws'][b][:,:3,-1] = new_scale_factor * batch['input_cams']['c2ws'][b][:,:3,-1]
130
+ batch['new_cams']['depths'][b] = new_scale_factor * batch['new_cams']['depths'][b]
131
+ batch['new_cams']['pointmaps'][b] = new_scale_factor * batch['new_cams']['pointmaps'][b]
132
+ batch['new_cams']['c2ws'][b][:,:3,-1] = new_scale_factor * batch['new_cams']['c2ws'][b][:,:3,-1]
133
+
134
+ pred['depths'][b] = new_scale_factor * pred['depths'][b]
135
+
136
+ gt['c2ws'][b][:,:3,-1] = new_scale_factor * gt['c2ws'][b][:,:3,-1]
137
+ gt['depths'][b] = new_scale_factor * gt['depths'][b]
138
+
139
+ gt['pointmaps'][b] = compute_pointmaps(gt['depths'][b].unsqueeze(1),gt['Ks'][b].unsqueeze(1),gt['c2ws'][b].unsqueeze(1)).squeeze(1)
140
+ pred['pointmaps'][b] = compute_pointmaps(pred['depths'][b].unsqueeze(1),gt['Ks'][b].unsqueeze(1),gt['c2ws'][b].unsqueeze(1)).squeeze(1)
141
+ return batch, pred, gt
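
A small check of the unprojection in `compute_pointmap` above, with placeholder intrinsics:

```python
import torch
from utils.batch_prep import compute_pointmap

depth = torch.full((4, 4), 2.0)                                   # 2 m everywhere
K = torch.tensor([[50., 0., 2.], [0., 50., 2.], [0., 0., 1.]])    # toy intrinsics
pts = compute_pointmap(depth, K)                                  # (4, 4, 3), camera frame
assert torch.allclose(pts[..., 2], depth)                         # z channel is the depth
```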
utils/collate.py ADDED
@@ -0,0 +1,7 @@
1
+ import torch
2
+
3
+ def collate(batch):
4
+ if isinstance(batch[0],dict):
5
+ return {k: collate([d[k] for d in batch]) for k in batch[0].keys()}
6
+ else:
7
+ return torch.stack([torch.stack(t) for t in batch])
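
A usage sketch for the collate function above, assuming each leaf is a list of per-view tensors (which is what the stacking in the else branch expects):

```python
import torch
from utils.collate import collate

sample = {'depths': [torch.rand(8, 8)], 'Ks': [torch.eye(3)]}
batch = collate([sample, sample])
print(batch['depths'].shape)   # torch.Size([2, 1, 8, 8])
print(batch['Ks'].shape)       # torch.Size([2, 1, 3, 3])
```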
utils/eval.py ADDED
@@ -0,0 +1,20 @@
1
+ import torch
2
+
3
+ def eval_pred(pred_dict, gt_dict,accuracy_tresh=[0.001,0.01,0.02,0.05,0.1,0.5]):
4
+ pointmaps_pred = pred_dict['pointmaps']
5
+ pointmaps_gt = gt_dict['pointmaps']
6
+ mask = gt_dict['valid_masks'].unsqueeze(-1).repeat(1,1,1,3)
7
+
8
+ points_pred = pointmaps_pred[mask].reshape(-1,3)
9
+ points_gt = pointmaps_gt[mask].reshape(-1,3)
10
+ dists = torch.norm(points_pred - points_gt, dim=1)
11
+ results = {'dist':dists.mean().detach().item()}
12
+ if 'classifier' in pred_dict:
13
+ classifier_pred = (torch.sigmoid(pred_dict['classifier']) > 0.5).bool()
14
+ classifier_gt = gt_dict['valid_masks']
15
+ results['classifier_acc'] = (classifier_pred == classifier_gt).float().mean().detach().item()
16
+
17
+ for tresh in accuracy_tresh:
18
+ acc = (dists < tresh).float().mean()
19
+ results[f'acc_{tresh}'] = acc.detach().item()
20
+ return results
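
A minimal sketch of `eval_pred` on synthetic predictions:

```python
import torch
from utils.eval import eval_pred

gt = {'pointmaps': torch.rand(1, 32, 32, 3),
      'valid_masks': torch.ones(1, 32, 32, dtype=torch.bool)}
pred = {'pointmaps': gt['pointmaps'] + 0.005}          # small constant offset
print(eval_pred(pred, gt))                             # mean distance + accuracy@thresholds
```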
utils/fusion.py ADDED
@@ -0,0 +1,476 @@
1
+ # Copyright (c) 2018 Andy Zeng
2
+
3
+ import numpy as np
4
+ import torch
5
+ from numba import njit, prange
6
+ from skimage import measure
7
+
8
+ try:
9
+ import pycuda.driver as cuda
10
+ import pycuda.autoinit
11
+ from pycuda.compiler import SourceModule
12
+ FUSION_GPU_MODE = 1
13
+ except Exception as err:
14
+ print('Warning: {}'.format(err))
15
+ print('Failed to import PyCUDA. Running fusion in CPU mode.')
16
+ FUSION_GPU_MODE = 0
17
+
18
+
19
+ class TSDFVolume:
20
+ """Volumetric TSDF Fusion of RGB-D Images.
21
+ """
22
+ def __init__(self, vol_bnds, voxel_size, use_gpu=True):
23
+ """Constructor.
24
+
25
+ Args:
26
+ vol_bnds (ndarray): An ndarray of shape (3, 2). Specifies the
27
+ xyz bounds (min/max) in meters.
28
+ voxel_size (float): The volume discretization in meters.
29
+ """
30
+ vol_bnds = np.asarray(vol_bnds)
31
+ assert vol_bnds.shape == (3, 2), "[!] `vol_bnds` should be of shape (3, 2)."
32
+
33
+ # Define voxel volume parameters
34
+ self._vol_bnds = vol_bnds
35
+ self._voxel_size = float(voxel_size)
36
+ self._trunc_margin = 5 * self._voxel_size # truncation on SDF
37
+ self._color_const = 256 * 256
38
+
39
+ # Adjust volume bounds and ensure C-order contiguous
40
+ self._vol_dim = np.ceil((self._vol_bnds[:,1]-self._vol_bnds[:,0])/self._voxel_size).copy(order='C').astype(int)
41
+ self._vol_bnds[:,1] = self._vol_bnds[:,0]+self._vol_dim*self._voxel_size
42
+ self._vol_origin = self._vol_bnds[:,0].copy(order='C').astype(np.float32)
43
+
44
+ print("Voxel volume size: {} x {} x {} - # points: {:,}".format(
45
+ self._vol_dim[0], self._vol_dim[1], self._vol_dim[2],
46
+ self._vol_dim[0]*self._vol_dim[1]*self._vol_dim[2])
47
+ )
48
+
49
+ # Initialize pointers to voxel volume in CPU memory
50
+ self._tsdf_vol_cpu = np.ones(self._vol_dim).astype(np.float32)
51
+ # for computing the cumulative moving average of observations per voxel
52
+ self._weight_vol_cpu = np.zeros(self._vol_dim).astype(np.float32)
53
+ self._color_vol_cpu = np.zeros(self._vol_dim).astype(np.float32)
54
+
55
+ #self.gpu_mode = False # CPU for debugging!!
56
+ self.gpu_mode = use_gpu and FUSION_GPU_MODE
57
+
58
+ # Copy voxel volumes to GPU
59
+ if self.gpu_mode:
60
+ self._tsdf_vol_gpu = cuda.mem_alloc(self._tsdf_vol_cpu.nbytes)
61
+ cuda.memcpy_htod(self._tsdf_vol_gpu,self._tsdf_vol_cpu)
62
+ self._weight_vol_gpu = cuda.mem_alloc(self._weight_vol_cpu.nbytes)
63
+ cuda.memcpy_htod(self._weight_vol_gpu,self._weight_vol_cpu)
64
+ self._color_vol_gpu = cuda.mem_alloc(self._color_vol_cpu.nbytes)
65
+ cuda.memcpy_htod(self._color_vol_gpu,self._color_vol_cpu)
66
+
67
+ # Cuda kernel function (C++)
68
+ self._cuda_src_mod = SourceModule("""
69
+ __global__ void integrate(float * tsdf_vol,
70
+ float * weight_vol,
71
+ float * color_vol,
72
+ float * vol_dim,
73
+ float * vol_origin,
74
+ float * cam_intr,
75
+ float * cam_pose,
76
+ float * other_params,
77
+ float * color_im,
78
+ float * depth_im) {
79
+ // Get voxel index
80
+ int gpu_loop_idx = (int) other_params[0];
81
+ int max_threads_per_block = blockDim.x;
82
+ int block_idx = blockIdx.z*gridDim.y*gridDim.x+blockIdx.y*gridDim.x+blockIdx.x;
83
+ int voxel_idx = gpu_loop_idx*gridDim.x*gridDim.y*gridDim.z*max_threads_per_block+block_idx*max_threads_per_block+threadIdx.x;
84
+ int vol_dim_x = (int) vol_dim[0];
85
+ int vol_dim_y = (int) vol_dim[1];
86
+ int vol_dim_z = (int) vol_dim[2];
87
+ if (voxel_idx > vol_dim_x*vol_dim_y*vol_dim_z)
88
+ return;
89
+ // Get voxel grid coordinates (note: be careful when casting)
90
+ float voxel_x = floorf(((float)voxel_idx)/((float)(vol_dim_y*vol_dim_z)));
91
+ float voxel_y = floorf(((float)(voxel_idx-((int)voxel_x)*vol_dim_y*vol_dim_z))/((float)vol_dim_z));
92
+ float voxel_z = (float)(voxel_idx-((int)voxel_x)*vol_dim_y*vol_dim_z-((int)voxel_y)*vol_dim_z);
93
+ // Voxel grid coordinates to world coordinates
94
+ float voxel_size = other_params[1];
95
+ float pt_x = vol_origin[0]+voxel_x*voxel_size;
96
+ float pt_y = vol_origin[1]+voxel_y*voxel_size;
97
+ float pt_z = vol_origin[2]+voxel_z*voxel_size;
98
+ // World coordinates to camera coordinates
99
+ float tmp_pt_x = pt_x-cam_pose[0*4+3];
100
+ float tmp_pt_y = pt_y-cam_pose[1*4+3];
101
+ float tmp_pt_z = pt_z-cam_pose[2*4+3];
102
+ float cam_pt_x = cam_pose[0*4+0]*tmp_pt_x+cam_pose[1*4+0]*tmp_pt_y+cam_pose[2*4+0]*tmp_pt_z;
103
+ float cam_pt_y = cam_pose[0*4+1]*tmp_pt_x+cam_pose[1*4+1]*tmp_pt_y+cam_pose[2*4+1]*tmp_pt_z;
104
+ float cam_pt_z = cam_pose[0*4+2]*tmp_pt_x+cam_pose[1*4+2]*tmp_pt_y+cam_pose[2*4+2]*tmp_pt_z;
105
+ // Camera coordinates to image pixels
106
+ int pixel_x = (int) roundf(cam_intr[0*3+0]*(cam_pt_x/cam_pt_z)+cam_intr[0*3+2]);
107
+ int pixel_y = (int) roundf(cam_intr[1*3+1]*(cam_pt_y/cam_pt_z)+cam_intr[1*3+2]);
108
+ // Skip if outside view frustum
109
+ int im_h = (int) other_params[2];
110
+ int im_w = (int) other_params[3];
111
+ if (pixel_x < 0 || pixel_x >= im_w || pixel_y < 0 || pixel_y >= im_h || cam_pt_z<0)
112
+ return;
113
+ // Skip invalid depth
114
+ float depth_value = depth_im[pixel_y*im_w+pixel_x];
115
+ if (depth_value == 0)
116
+ return;
117
+ // Integrate TSDF
118
+ float trunc_margin = other_params[4];
119
+ float depth_diff = depth_value-cam_pt_z;
120
+ if (depth_diff < -trunc_margin)
121
+ return;
122
+ float dist = fmin(1.0f,depth_diff/trunc_margin);
123
+ float w_old = weight_vol[voxel_idx];
124
+ float obs_weight = other_params[5];
125
+ float w_new = w_old + obs_weight;
126
+ weight_vol[voxel_idx] = w_new;
127
+ tsdf_vol[voxel_idx] = (tsdf_vol[voxel_idx]*w_old+obs_weight*dist)/w_new;
128
+ // Integrate color
129
+ float old_color = color_vol[voxel_idx];
130
+ float old_b = floorf(old_color/(256*256));
131
+ float old_g = floorf((old_color-old_b*256*256)/256);
132
+ float old_r = old_color-old_b*256*256-old_g*256;
133
+ float new_color = color_im[pixel_y*im_w+pixel_x];
134
+ float new_b = floorf(new_color/(256*256));
135
+ float new_g = floorf((new_color-new_b*256*256)/256);
136
+ float new_r = new_color-new_b*256*256-new_g*256;
137
+ new_b = fmin(roundf((old_b*w_old+obs_weight*new_b)/w_new),255.0f);
138
+ new_g = fmin(roundf((old_g*w_old+obs_weight*new_g)/w_new),255.0f);
139
+ new_r = fmin(roundf((old_r*w_old+obs_weight*new_r)/w_new),255.0f);
140
+ color_vol[voxel_idx] = new_b*256*256+new_g*256+new_r;
141
+ }""")
142
+
143
+ self._cuda_integrate = self._cuda_src_mod.get_function("integrate")
144
+
145
+ # Determine block/grid size on GPU
146
+ gpu_dev = cuda.Device(0)
147
+ self._max_gpu_threads_per_block = gpu_dev.MAX_THREADS_PER_BLOCK
148
+ n_blocks = int(np.ceil(float(np.prod(self._vol_dim))/float(self._max_gpu_threads_per_block)))
149
+ grid_dim_x = min(gpu_dev.MAX_GRID_DIM_X,int(np.floor(np.cbrt(n_blocks))))
150
+ grid_dim_y = min(gpu_dev.MAX_GRID_DIM_Y,int(np.floor(np.sqrt(n_blocks/grid_dim_x))))
151
+ grid_dim_z = min(gpu_dev.MAX_GRID_DIM_Z,int(np.ceil(float(n_blocks)/float(grid_dim_x*grid_dim_y))))
152
+ self._max_gpu_grid_dim = np.array([grid_dim_x,grid_dim_y,grid_dim_z]).astype(int)
153
+ self._n_gpu_loops = int(np.ceil(float(np.prod(self._vol_dim))/float(np.prod(self._max_gpu_grid_dim)*self._max_gpu_threads_per_block)))
154
+
155
+ else:
156
+ # Get voxel grid coordinates
157
+ xv, yv, zv = np.meshgrid(
158
+ range(self._vol_dim[0]),
159
+ range(self._vol_dim[1]),
160
+ range(self._vol_dim[2]),
161
+ indexing='ij'
162
+ )
163
+ self.vox_coords = np.concatenate([
164
+ xv.reshape(1,-1),
165
+ yv.reshape(1,-1),
166
+ zv.reshape(1,-1)
167
+ ], axis=0).astype(int).T
168
+
169
+ @staticmethod
170
+ @njit(parallel=True)
171
+ def vox2world(vol_origin, vox_coords, vox_size):
172
+ """Convert voxel grid coordinates to world coordinates.
173
+ """
174
+ vol_origin = vol_origin.astype(np.float32)
175
+ vox_coords = vox_coords.astype(np.float32)
176
+ cam_pts = np.empty_like(vox_coords, dtype=np.float32)
177
+ for i in prange(vox_coords.shape[0]):
178
+ for j in range(3):
179
+ cam_pts[i, j] = vol_origin[j] + (vox_size * vox_coords[i, j])
180
+ return cam_pts
181
+
182
+ @staticmethod
183
+ @njit(parallel=True)
184
+ def cam2pix(cam_pts, intr):
185
+ """Convert camera coordinates to pixel coordinates.
186
+ """
187
+ intr = intr.astype(np.float32)
188
+ fx, fy = intr[0, 0], intr[1, 1]
189
+ cx, cy = intr[0, 2], intr[1, 2]
190
+ pix = np.empty((cam_pts.shape[0], 2), dtype=np.int64)
191
+ for i in prange(cam_pts.shape[0]):
192
+ pix[i, 0] = int(np.round((cam_pts[i, 0] * fx / cam_pts[i, 2]) + cx))
193
+ pix[i, 1] = int(np.round((cam_pts[i, 1] * fy / cam_pts[i, 2]) + cy))
194
+ return pix
195
+
196
+ @staticmethod
197
+ @njit(parallel=True)
198
+ def integrate_tsdf(tsdf_vol, dist, w_old, obs_weight):
199
+ """Integrate the TSDF volume.
200
+ """
201
+ tsdf_vol_int = np.empty_like(tsdf_vol, dtype=np.float32)
202
+ w_new = np.empty_like(w_old, dtype=np.float32)
203
+ for i in prange(len(tsdf_vol)):
204
+ w_new[i] = w_old[i] + obs_weight
205
+ tsdf_vol_int[i] = (w_old[i] * tsdf_vol[i] + obs_weight * dist[i]) / w_new[i]
206
+ return tsdf_vol_int, w_new
207
+
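The update in integrate_tsdf above is a per-voxel weighted running average of truncated signed distances. A minimal standalone sketch of the same rule, assuming a truncation margin of 0.05 m, unit observation weights, and the initial TSDF/weight values used by this class:

# hypothetical single voxel observed in two frames
trunc_margin = 0.05
tsdf, weight = 1.0, 0.0                          # TSDF volume starts at 1, weights at 0
for depth_diff in (0.02, -0.01):                 # measured depth minus voxel depth, in meters
    dist = min(1.0, depth_diff / trunc_margin)   # truncated signed distance
    weight_new = weight + 1.0                    # obs_weight = 1
    tsdf = (tsdf * weight + dist) / weight_new   # weighted running average
    weight = weight_new
print(tsdf, weight)                              # 0.1, 2.0 after the two observations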
208
+ def integrate(self, color_im, depth_im, cam_intr, cam_pose, obs_weight=1.,mask=None):
209
+ """Integrate an RGB-D frame into the TSDF volume.
210
+
211
+ Args:
212
+ color_im (ndarray): An RGB image of shape (H, W, 3).
213
+ depth_im (ndarray): A depth image of shape (H, W).
214
+ cam_intr (ndarray): The camera intrinsics matrix of shape (3, 3).
215
+ cam_pose (ndarray): The camera pose (i.e. extrinsics) of shape (4, 4).
216
+ obs_weight (float): The weight to assign for the current observation. A higher
217
+ value gives this observation more influence on the fused TSDF and color.
218
+ mask (ndarray, optional): Boolean mask of shape (H, W); when provided (CPU mode
+ only), only pixels where the mask is True are integrated.
+ """
219
+ im_h, im_w = depth_im.shape
220
+
221
+ # Fold RGB color image into a single channel image
222
+ color_im = color_im.astype(np.float32)
223
+ color_im = np.floor(color_im[...,2]*self._color_const + color_im[...,1]*256 + color_im[...,0])
224
+
225
+ if self.gpu_mode: # GPU mode: integrate voxel volume (calls CUDA kernel)
226
+ # no mask implemented yet
227
+ for gpu_loop_idx in range(self._n_gpu_loops):
228
+ self._cuda_integrate(self._tsdf_vol_gpu,
229
+ self._weight_vol_gpu,
230
+ self._color_vol_gpu,
231
+ cuda.InOut(self._vol_dim.astype(np.float32)),
232
+ cuda.InOut(self._vol_origin.astype(np.float32)),
233
+ cuda.InOut(cam_intr.reshape(-1).astype(np.float32)),
234
+ cuda.InOut(cam_pose.reshape(-1).astype(np.float32)),
235
+ cuda.InOut(np.asarray([
236
+ gpu_loop_idx,
237
+ self._voxel_size,
238
+ im_h,
239
+ im_w,
240
+ self._trunc_margin,
241
+ obs_weight
242
+ ], np.float32)),
243
+ cuda.InOut(color_im.reshape(-1).astype(np.float32)),
244
+ cuda.InOut(depth_im.reshape(-1).astype(np.float32)),
245
+ block=(self._max_gpu_threads_per_block,1,1),
246
+ grid=(
247
+ int(self._max_gpu_grid_dim[0]),
248
+ int(self._max_gpu_grid_dim[1]),
249
+ int(self._max_gpu_grid_dim[2]),
250
+ )
251
+ )
252
+ else: # CPU mode: integrate voxel volume (vectorized implementation)
253
+ # Convert voxel grid coordinates to pixel coordinates
254
+ cam_pts = self.vox2world(self._vol_origin, self.vox_coords, self._voxel_size)
255
+ cam_pts = rigid_transform(cam_pts, np.linalg.inv(cam_pose))
256
+ pix_z = cam_pts[:, 2]
257
+ pix = self.cam2pix(cam_pts, cam_intr)
258
+ pix_x, pix_y = pix[:, 0], pix[:, 1]
259
+
260
+ # Eliminate pixels outside view frustum
261
+ valid_pix = np.logical_and(pix_x >= 0,
262
+ np.logical_and(pix_x < im_w,
263
+ np.logical_and(pix_y >= 0,
264
+ np.logical_and(pix_y < im_h,
265
+ pix_z > 0))))
266
+ if mask is not None:
267
+ mask_queries = mask[pix_y[valid_pix],pix_x[valid_pix]]
268
+ valid_pix[valid_pix] = mask_queries # the entries selected by valid_pix are already True
269
+
270
+ depth_val = np.zeros(pix_x.shape)
271
+ depth_val[valid_pix] = depth_im[pix_y[valid_pix], pix_x[valid_pix]]
272
+
273
+ # Integrate TSDF
274
+ depth_diff = depth_val - pix_z
275
+ valid_pts = np.logical_and(depth_val > 0, depth_diff >= -self._trunc_margin)
276
+ dist = np.minimum(1, depth_diff / self._trunc_margin)
277
+ valid_vox_x = self.vox_coords[valid_pts, 0]
278
+ valid_vox_y = self.vox_coords[valid_pts, 1]
279
+ valid_vox_z = self.vox_coords[valid_pts, 2]
280
+ w_old = self._weight_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z]
281
+ tsdf_vals = self._tsdf_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z]
282
+ valid_dist = dist[valid_pts]
283
+ tsdf_vol_new, w_new = self.integrate_tsdf(tsdf_vals, valid_dist, w_old, obs_weight)
284
+ self._weight_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = w_new
285
+ self._tsdf_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = tsdf_vol_new
286
+
287
+ # Integrate color
288
+ old_color = self._color_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z]
289
+ old_b = np.floor(old_color / self._color_const)
290
+ old_g = np.floor((old_color-old_b*self._color_const)/256)
291
+ old_r = old_color - old_b*self._color_const - old_g*256
292
+ new_color = color_im[pix_y[valid_pts],pix_x[valid_pts]]
293
+ new_b = np.floor(new_color / self._color_const)
294
+ new_g = np.floor((new_color - new_b*self._color_const) /256)
295
+ new_r = new_color - new_b*self._color_const - new_g*256
296
+ new_b = np.minimum(255., np.round((w_old*old_b + obs_weight*new_b) / w_new))
297
+ new_g = np.minimum(255., np.round((w_old*old_g + obs_weight*new_g) / w_new))
298
+ new_r = np.minimum(255., np.round((w_old*old_r + obs_weight*new_r) / w_new))
299
+ self._color_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = new_b*self._color_const + new_g*256 + new_r
300
+
301
+ def get_volume(self):
302
+ if self.gpu_mode:
303
+ cuda.memcpy_dtoh(self._tsdf_vol_cpu, self._tsdf_vol_gpu)
304
+ cuda.memcpy_dtoh(self._color_vol_cpu, self._color_vol_gpu)
305
+ return self._tsdf_vol_cpu, self._color_vol_cpu
306
+
307
+ def get_point_cloud(self):
308
+ """Extract a point cloud from the voxel volume.
309
+ """
310
+ tsdf_vol, color_vol = self.get_volume()
311
+
312
+ # Marching cubes
313
+ verts = measure.marching_cubes(tsdf_vol, level=0, method='lewiner')[0]
314
+ verts_ind = np.round(verts).astype(int)
315
+ verts = verts*self._voxel_size + self._vol_origin
316
+
317
+ # Get vertex colors
318
+ rgb_vals = color_vol[verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]]
319
+ colors_b = np.floor(rgb_vals / self._color_const)
320
+ colors_g = np.floor((rgb_vals - colors_b*self._color_const) / 256)
321
+ colors_r = rgb_vals - colors_b*self._color_const - colors_g*256
322
+ colors = np.floor(np.asarray([colors_r, colors_g, colors_b])).T
323
+ colors = colors.astype(np.uint8)
324
+
325
+ pc = np.hstack([verts, colors])
326
+ return pc
327
+
328
+ def get_mesh(self):
329
+ """Compute a mesh from the voxel volume using marching cubes.
330
+ """
331
+ tsdf_vol, color_vol = self.get_volume()
332
+
333
+ # Marching cubes
334
+ verts, faces, norms, vals = measure.marching_cubes(tsdf_vol, level=0, method='lewiner')
335
+ verts_ind = np.round(verts).astype(int)
336
+ verts = verts*self._voxel_size+self._vol_origin # voxel grid coordinates to world coordinates
337
+
338
+ # Get vertex colors
339
+ rgb_vals = color_vol[verts_ind[:,0], verts_ind[:,1], verts_ind[:,2]]
340
+ colors_b = np.floor(rgb_vals/self._color_const)
341
+ colors_g = np.floor((rgb_vals-colors_b*self._color_const)/256)
342
+ colors_r = rgb_vals-colors_b*self._color_const-colors_g*256
343
+ colors = np.floor(np.asarray([colors_r,colors_g,colors_b])).T
344
+ colors = colors.astype(np.uint8)
345
+ return verts, faces, norms, colors
346
+
347
+
348
+ def rigid_transform(xyz, transform):
349
+ """Applies a rigid transform to an (N, 3) pointcloud.
350
+ """
351
+ xyz_h = np.hstack([xyz, np.ones((len(xyz), 1), dtype=np.float32)])
352
+ xyz_t_h = np.dot(transform, xyz_h.T).T
353
+ return xyz_t_h[:, :3]
354
+
355
+
356
+ def get_view_frustum(depth_im, cam_intr, cam_pose):
357
+ """Get corners of 3D camera view frustum of depth image
358
+ """
359
+ im_h = depth_im.shape[0]
360
+ im_w = depth_im.shape[1]
361
+ max_depth = np.max(depth_im)
362
+ view_frust_pts = np.array([
363
+ (np.array([0,0,0,im_w,im_w])-cam_intr[0,2])*np.array([0,max_depth,max_depth,max_depth,max_depth])/cam_intr[0,0],
364
+ (np.array([0,0,im_h,0,im_h])-cam_intr[1,2])*np.array([0,max_depth,max_depth,max_depth,max_depth])/cam_intr[1,1],
365
+ np.array([0,max_depth,max_depth,max_depth,max_depth])
366
+ ])
367
+ view_frust_pts = rigid_transform(view_frust_pts.T, cam_pose).T
368
+ return view_frust_pts
369
+
370
+
371
+ def meshwrite(filename, verts, faces, norms, colors):
372
+ """Save a 3D mesh to a polygon .ply file.
373
+ """
374
+ # Write header
375
+ ply_file = open(filename,'w')
376
+ ply_file.write("ply\n")
377
+ ply_file.write("format ascii 1.0\n")
378
+ ply_file.write("element vertex %d\n"%(verts.shape[0]))
379
+ ply_file.write("property float x\n")
380
+ ply_file.write("property float y\n")
381
+ ply_file.write("property float z\n")
382
+ ply_file.write("property float nx\n")
383
+ ply_file.write("property float ny\n")
384
+ ply_file.write("property float nz\n")
385
+ ply_file.write("property uchar red\n")
386
+ ply_file.write("property uchar green\n")
387
+ ply_file.write("property uchar blue\n")
388
+ ply_file.write("element face %d\n"%(faces.shape[0]))
389
+ ply_file.write("property list uchar int vertex_index\n")
390
+ ply_file.write("end_header\n")
391
+
392
+ # Write vertex list
393
+ for i in range(verts.shape[0]):
394
+ ply_file.write("%f %f %f %f %f %f %d %d %d\n"%(
395
+ verts[i,0], verts[i,1], verts[i,2],
396
+ norms[i,0], norms[i,1], norms[i,2],
397
+ colors[i,0], colors[i,1], colors[i,2],
398
+ ))
399
+
400
+ # Write face list
401
+ for i in range(faces.shape[0]):
402
+ ply_file.write("3 %d %d %d\n"%(faces[i,0], faces[i,1], faces[i,2]))
403
+
404
+ ply_file.close()
405
+
406
+
407
+ def pcwrite(filename, xyzrgb):
408
+ """Save a point cloud to a polygon .ply file.
409
+ """
410
+ xyz = xyzrgb[:, :3]
411
+ rgb = xyzrgb[:, 3:].astype(np.uint8)
412
+
413
+ # Write header
414
+ ply_file = open(filename,'w')
415
+ ply_file.write("ply\n")
416
+ ply_file.write("format ascii 1.0\n")
417
+ ply_file.write("element vertex %d\n"%(xyz.shape[0]))
418
+ ply_file.write("property float x\n")
419
+ ply_file.write("property float y\n")
420
+ ply_file.write("property float z\n")
421
+ ply_file.write("property uchar red\n")
422
+ ply_file.write("property uchar green\n")
423
+ ply_file.write("property uchar blue\n")
424
+ ply_file.write("end_header\n")
425
+
426
+ # Write vertex list
427
+ for i in range(xyz.shape[0]):
428
+ ply_file.write("%f %f %f %d %d %d\n"%(
429
+ xyz[i, 0], xyz[i, 1], xyz[i, 2],
430
+ rgb[i, 0], rgb[i, 1], rgb[i, 2],
431
+ ))
+ ply_file.close()
432
+
433
+ def get_vol_bds(pred_depths : torch.Tensor, pred_c2ws : torch.Tensor, pred_intr : torch.Tensor):
434
+ n_views = pred_depths.shape[0]
435
+ vol_bnds = np.zeros((3,2))
436
+
437
+ for i in range(n_views):
438
+ intr = pred_intr[i].cpu().numpy()
439
+ c2w = pred_c2ws[i].cpu().numpy()
440
+ depth = pred_depths[i].cpu().numpy()
441
+ view_frust_pts = get_view_frustum(depth, intr, c2w)
442
+ vol_bnds[:,0] = np.minimum(vol_bnds[:,0], np.amin(view_frust_pts, axis=1))
443
+ vol_bnds[:,1] = np.maximum(vol_bnds[:,1], np.amax(view_frust_pts, axis=1))
444
+
445
+ return vol_bnds
446
+
447
+ def fuse_batch(pred_dict: dict, gt_dict: dict, batch:dict,voxel_size: float = 0.02):
448
+ pred_depths = pred_dict['pointmaps'][...,-1] # depth here is just z, assuming the predicted point map is in camera frame
449
+ pred_c2ws = batch['new_cams']['c2ws']
450
+ pred_intr = batch['new_cams']['Ks']
451
+ pred_masks = batch['new_cams']['valid_masks']
452
+ B = pred_depths.shape[0]
453
+ n_views = pred_depths.shape[1]
454
+
455
+ meshes = []
456
+ for i in range(B):
457
+ intrs = pred_intr[i]
458
+ c2ws = pred_c2ws[i]
459
+ depths = pred_depths[i]
460
+ vol_bnds = get_vol_bds(depths, c2ws, intrs)
461
+ tsdf_vol = TSDFVolume(vol_bnds, voxel_size=voxel_size)
462
+ masks = pred_masks[i]
463
+
464
+ for j in range(n_views):
465
+ intr = intrs[j]
466
+ c2w = c2ws[j]
467
+ depth = depths[j]
468
+ mask = masks[j]
469
+ depth[~mask] = 0
470
+ img = torch.zeros_like(depth,dtype=torch.uint8).unsqueeze(-1).repeat(1,1,3)
471
+ img[:,:,-1] = 255
472
+ tsdf_vol.integrate(img.cpu().numpy(), depth.cpu().numpy(), intr.cpu().numpy(), c2w.cpu().numpy(), obs_weight=1.)
473
+
474
+ verts, faces, norms, colors = tsdf_vol.get_mesh()
475
+ meshes.append(dict(verts=verts, faces=faces, norms=norms, colors=colors))
476
+ return meshes
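A minimal end-to-end usage sketch of the fusion utilities above, assuming this file is importable as utils.fusion (the actual path is not shown here), a synthetic step-shaped depth map, an identity camera pose, and gt_dict=None (fuse_batch does not read it):

import torch
from utils.fusion import fuse_batch, meshwrite   # assumed module path

H, W = 48, 64
depth = torch.full((H, W), 0.8)
depth[:, : W // 2] = 0.5                          # near plane on the left, far plane on the right
K = torch.tensor([[60., 0., W / 2], [0., 60., H / 2], [0., 0., 1.]])

pointmaps = torch.zeros(1, 1, H, W, 3)            # (B, n_views, H, W, 3); only the z channel is read
pointmaps[..., 2] = depth
pred_dict = {'pointmaps': pointmaps}
batch = {'new_cams': {
    'c2ws': torch.eye(4).view(1, 1, 4, 4),
    'Ks': K.view(1, 1, 3, 3),
    'valid_masks': torch.ones(1, 1, H, W, dtype=torch.bool),
}}

meshes = fuse_batch(pred_dict, None, batch, voxel_size=0.01)
m = meshes[0]
meshwrite('fused.ply', m['verts'], m['faces'], m['norms'], m['colors'])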
utils/geometry.py ADDED
@@ -0,0 +1,195 @@
1
+ import numpy as np
2
+ import torch
3
+ import copy
4
+ from utils.utils import invalid_to_nans, invalid_to_zeros
5
+
6
+ def compute_pointmap(depth, cam2w, intrinsics):
7
+ fx, fy = intrinsics[0, 0], intrinsics[1, 1]
8
+ cx, cy = intrinsics[0, 2], intrinsics[1, 2]
9
+ h, w = depth.shape
10
+
11
+ i, j = np.meshgrid(np.arange(w), np.arange(h), indexing='xy')
12
+
13
+ x_cam = (i - cx) * depth / fx
14
+ y_cam = (j - cy) * depth / fy
15
+
16
+ points_cam = np.stack([x_cam, y_cam, depth], axis=-1)
17
+ points_world = np.dot(cam2w[:3, :3], points_cam.reshape(-1, 3).T).T + cam2w[:3, 3]
18
+ points_world = points_world.reshape(h, w, 3)
19
+
20
+ return points_world
21
+
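A small sanity check for compute_pointmap, assuming a synthetic constant-depth image and an identity camera pose; the pixel at the principal point should unproject to (0, 0, depth):

import numpy as np
from utils.geometry import compute_pointmap

K = np.array([[100., 0., 32.], [0., 100., 24.], [0., 0., 1.]])
depth = np.full((48, 64), 2.0)                # 2 m everywhere
pts = compute_pointmap(depth, np.eye(4), K)   # (48, 64, 3) points in the world frame
print(pts[24, 32])                            # ~[0. 0. 2.] at the principal point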
22
+ def invert_poses(raw_poses):
23
+ poses = copy.deepcopy(raw_poses)
24
+ original_shape = poses.shape
25
+ poses = poses.reshape(-1, 4, 4)
26
+ R = copy.deepcopy(poses[:, :3, :3])
27
+ t = copy.deepcopy(poses[:, :3, 3])
28
+ poses[:, :3, :3] = R.transpose(1, 2)
29
+ poses[:, :3, 3] = torch.bmm(-R.transpose(1, 2), t.unsqueeze(-1)).squeeze(-1)
30
+ poses = poses.reshape(*original_shape)
31
+ return poses
32
+
33
+ def center_pointmaps_set(dict,w2cs):
34
+ swap_dim = False
35
+ if dict["pointmaps"].shape[1] == 3:
36
+ swap_dim = True
37
+ dict["pointmaps"] = dict["pointmaps"].transpose(1,-1)
38
+
39
+ original_shape = dict["pointmaps"].shape
40
+ device = dict["pointmaps"].device
41
+ B = original_shape[0]
42
+
43
+ # recompute pointmaps in camera frame
44
+ pointmaps = dict["pointmaps"]
45
+ pointmaps_h = torch.cat([pointmaps,torch.ones(pointmaps.shape[:-1]+(1,)).to(device)],dim=-1)
46
+ pointmaps_h = pointmaps_h.reshape(B,-1,4)
47
+ pointmaps_recentered_h = torch.bmm(w2cs,pointmaps_h.transpose(1,2)).transpose(1,2)
48
+ pointmaps_recentered = pointmaps_recentered_h[...,:3]/pointmaps_recentered_h[...,3:4]
49
+ pointmaps_recentered = pointmaps_recentered.reshape(*original_shape)
50
+
51
+ # recompute c2ws
52
+ if "c2ws" in dict:
53
+ c2ws_recentered = torch.bmm(w2cs,dict["c2ws"].reshape(-1,4,4))
54
+ c2ws_recentered = c2ws_recentered.reshape(dict["c2ws"].shape)
55
+ dict["c2ws"] = c2ws_recentered
56
+
57
+ # assign to dict
58
+ dict["pointmaps"] = pointmaps_recentered
59
+ if swap_dim:
60
+ dict["pointmaps"] = dict["pointmaps"].transpose(1,-1)
61
+ return dict
62
+
63
+ def center_pointmaps(batch):
64
+ original_poses = batch["new_cams"]["c2ws"] # assuming first camera is the one we want to predict
65
+ w2cs = invert_poses(batch["new_cams"]["c2ws"])
66
+
67
+ batch["new_cams"] = center_pointmaps_set(batch["new_cams"],w2cs)
68
+ batch["input_cams"] = center_pointmaps_set(batch["input_cams"],w2cs)
69
+ batch["original_poses"] = original_poses
70
+ return batch
71
+
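Because center_pointmaps re-expresses both camera sets in the frames of the new cameras, the new c2ws become the identity afterwards. A minimal check sketch, assuming flattened (B, 4, 4) poses and dummy (B, H, W, 3) pointmaps as expected by center_pointmaps_set:

import torch
from utils.geometry import center_pointmaps

B, H, W = 2, 8, 8
c2ws = torch.eye(4).repeat(B, 1, 1)
c2ws[:, :3, 3] = torch.randn(B, 3)            # random camera translations
batch = {
    "new_cams":   {"pointmaps": torch.randn(B, H, W, 3), "c2ws": c2ws.clone()},
    "input_cams": {"pointmaps": torch.randn(B, H, W, 3), "c2ws": c2ws.clone()},
}
batch = center_pointmaps(batch)
# after centering, each new camera sits at the origin of its own frame
assert torch.allclose(batch["new_cams"]["c2ws"], torch.eye(4).expand(B, 4, 4), atol=1e-5)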
72
+
73
+ def uncenter_pointmaps(pred,gt,batch):
74
+ original_poses = batch["original_poses"]
75
+
76
+ batch["new_cams"] = center_pointmaps_set(batch["new_cams"],original_poses)
77
+ batch["input_cams"] = center_pointmaps_set(batch["input_cams"],original_poses)
78
+
79
+ #gt = center_pointmaps_set(gt,original_poses)
80
+ #pred = center_pointmaps_set(pred,original_poses)
81
+ return pred, gt, batch
82
+
83
+ def compute_rays(batch):
84
+ h, w = batch["new_cams"]["pointmaps"].shape[-3:-1]
85
+ B = batch["new_cams"]["pointmaps"].shape[0]
86
+ device = batch["new_cams"]["pointmaps"].device
87
+ Ks = batch["new_cams"]["Ks"]
88
+ i_s, j_s = np.meshgrid(np.arange(w), np.arange(h), indexing='xy')
89
+ i_s, j_s = torch.tensor(i_s).repeat(B,1,1).to(device), torch.tensor(j_s).repeat(B,1,1).to(device)
90
+
91
+ f_x = Ks[:,0,0].reshape(-1,1,1)
92
+ f_y = Ks[:,1,1].reshape(-1,1,1)
93
+ c_x = Ks[:,0,2].reshape(-1,1,1)
94
+ c_y = Ks[:,1,2].reshape(-1,1,1)
95
+
96
+ # compute rays with z=1
97
+ x_cam = (i_s - c_x) / f_x
98
+ y_cam = (j_s - c_y) / f_y
99
+ rays = torch.cat([x_cam.unsqueeze(-1),y_cam.unsqueeze(-1)],dim=-1)
100
+ return rays
101
+
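compute_rays returns only the x/y components of the unit-depth ray through each pixel (z = 1 is implicit). A quick sketch, assuming one camera with focal length 100 px and the principal point at the image center:

import torch
from utils.geometry import compute_rays

B, H, W = 1, 48, 64
K = torch.tensor([[100., 0., W / 2], [0., 100., H / 2], [0., 0., 1.]])
batch = {"new_cams": {"pointmaps": torch.zeros(B, H, W, 3), "Ks": K.view(1, 3, 3)}}
rays = compute_rays(batch)                    # (B, H, W, 2): ((x - cx) / fx, (y - cy) / fy)
print(rays.shape, rays[0, H // 2, W // 2])    # the central pixel maps to ~(0, 0)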
102
+ def normalize_pointcloud(pts1, pts2=None, norm_mode='avg_dis', valid1=None, valid2=None, valid3=None, ret_factor=False,pts3=None):
103
+ assert pts1.ndim >= 3 and pts1.shape[-1] == 3
104
+ assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3)
105
+ norm_mode, dis_mode = norm_mode.split('_')
106
+
107
+ if norm_mode == 'avg':
108
+ # gather all points together (joint normalization)
109
+ nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3)
110
+ nan_pts2, nnz2 = invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0)
111
+ all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
112
+ if pts3 is not None:
113
+ nan_pts3, nnz3 = invalid_to_zeros(pts3, valid3, ndim=3)
114
+ all_pts = torch.cat((all_pts, nan_pts3), dim=1)
115
+ nnz1 += nnz3
116
+ # compute distance to origin
117
+ all_dis = all_pts.norm(dim=-1)
118
+ if dis_mode == 'dis':
119
+ pass # do nothing
120
+ elif dis_mode == 'log1p':
121
+ all_dis = torch.log1p(all_dis)
122
+ elif dis_mode == 'warp-log1p':
123
+ # actually warp input points before normalizing them
124
+ log_dis = torch.log1p(all_dis)
125
+ warp_factor = log_dis / all_dis.clip(min=1e-8)
126
+ H1, W1 = pts1.shape[1:-1]
127
+ pts1 = pts1 * warp_factor[:,:W1*H1].view(-1,H1,W1,1)
128
+ if pts2 is not None:
129
+ H2, W2 = pts2.shape[1:-1]
130
+ pts2 = pts2 * warp_factor[:,W1*H1:].view(-1,H2,W2,1)
131
+ all_dis = log_dis # this is their true distance afterwards
132
+ else:
133
+ raise ValueError(f'bad {dis_mode=}')
134
+
135
+ norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8)
136
+ else:
137
+ # gather all points together (joint normalization)
138
+ nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3)
139
+ nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None
140
+ all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
141
+
142
+ # compute distance to origin
143
+ all_dis = all_pts.norm(dim=-1)
144
+
145
+ if norm_mode == 'avg':
146
+ norm_factor = all_dis.nanmean(dim=1)
147
+ elif norm_mode == 'median':
148
+ norm_factor = all_dis.nanmedian(dim=1).values.detach()
149
+ elif norm_mode == 'sqrt':
150
+ norm_factor = all_dis.sqrt().nanmean(dim=1)**2
151
+ else:
152
+ raise ValueError(f'bad {norm_mode=}')
153
+
154
+ norm_factor = norm_factor.clip(min=1e-8)
155
+ while norm_factor.ndim < pts1.ndim:
156
+ norm_factor.unsqueeze_(-1)
157
+
158
+ res = (pts1 / norm_factor,)
159
+ if pts2 is not None:
160
+ res = res + (pts2 / norm_factor,)
161
+ if pts3 is not None:
162
+ res = res + (pts3 / norm_factor,)
163
+ if ret_factor:
164
+ res = res + (norm_factor,)
165
+ return res
166
+
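A short usage sketch of normalize_pointcloud in 'avg_dis' mode, assuming a single (B, H, W, 3) point map with a boolean validity mask; the returned factor is the mean distance of the valid points to the origin, so the normalized valid points have mean distance ~1:

import torch
from utils.geometry import normalize_pointcloud

pts = torch.randn(2, 16, 16, 3) * 5.0
valid = torch.rand(2, 16, 16) > 0.2
pts_n, factor = normalize_pointcloud(pts, norm_mode='avg_dis', valid1=valid, ret_factor=True)
print(pts_n[valid].norm(dim=-1).mean(), factor.view(-1))   # ~1.0 and the per-batch scale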
167
+ def compute_pointmap_torch(depth, cam2w, intrinsics,device='cuda'):
168
+ fx, fy = intrinsics[0, 0], intrinsics[1, 1]
169
+ cx, cy = intrinsics[0, 2], intrinsics[1, 2]
170
+ h, w = depth.shape
171
+
172
+ #i, j = np.meshgrid(np.arange(w), np.arange(h), indexing='xy')
173
+ i, j = torch.meshgrid(torch.arange(w).to(device), torch.arange(h).to(device), indexing='xy')
174
+ x_cam = (i - cx) * depth / fx
175
+ y_cam = (j - cy) * depth / fy
176
+
177
+ points_cam = torch.stack([x_cam, y_cam, depth], dim=-1)
178
+ points_world = (cam2w[:3, :3] @ points_cam.reshape(-1, 3).T).T + cam2w[:3, 3]
179
+ points_world = points_world.reshape(h, w, 3)
180
+
181
+ return points_world
182
+
183
+ def depth2pts(depths, Ks):
184
+ """
185
+ Convert depth map to 3D points
186
+ """
187
+ device = depths.device
188
+ B = depths.shape[0]
189
+ pts = []
190
+ for b in range(B):
191
+ depth_b = depths[b]
192
+ K = Ks[b]
193
+ pts.append(compute_pointmap_torch(depth_b,torch.eye(4).to(device), K,device))
194
+ pts = torch.stack(pts, dim=0)
195
+ return pts
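A small usage sketch of depth2pts on CPU, assuming constant-depth maps; since the extrinsic is the identity, the principal-point pixel maps to (0, 0, depth):

import torch
from utils.geometry import depth2pts

depths = torch.full((2, 48, 64), 1.5)                 # (B, H, W) depth maps
Ks = torch.tensor([[100., 0., 32.], [0., 100., 24.], [0., 0., 1.]]).repeat(2, 1, 1)
pts = depth2pts(depths, Ks)                           # (2, 48, 64, 3) points in each camera frame
print(pts.shape, pts[0, 24, 32])                      # ~[0, 0, 1.5] at the principal point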
utils/misc.py ADDED
@@ -0,0 +1,122 @@
1
+ import os
2
+ import copy
3
+ from pathlib import Path
4
+ import torch
5
+ import torch.distributed as dist
6
+ import numpy as np
7
+ import math
8
+ import socket
9
+ # source: https://github.com/LTH14/mar/blob/main/util/misc.py
10
+
11
+ def prep_torch():
12
+ cpu_cores = get_cpu_cores()
13
+ torch.set_num_threads(cpu_cores) # intra-op threads (e.g., matrix ops)
14
+ torch.set_num_interop_threads(cpu_cores) # inter-op parallelism
15
+
16
+ os.environ["OMP_NUM_THREADS"] = str(cpu_cores)
17
+ os.environ["MKL_NUM_THREADS"] = str(cpu_cores)
18
+ os.environ["OPENBLAS_NUM_THREADS"] = str(cpu_cores)
19
+
20
+ def get_cpu_cores():
21
+ hostname = socket.gethostname()
22
+ if "bridges2" in hostname:
23
+ return int(os.environ["SLURM_JOB_CPUS_PER_NODE"])
24
+ else:
25
+ try:
26
+ with open("/sys/fs/cgroup/cpu/cpu.cfs_quota_us", "r") as f:
27
+ quota = int(f.read().strip())
28
+ with open("/sys/fs/cgroup/cpu/cpu.cfs_period_us", "r") as f:
29
+ period = int(f.read().strip())
30
+ if quota > 0:
31
+ return max(1, quota // period)
32
+ except Exception as e:
33
+ return os.cpu_count()
34
+
35
+ def setup_distributed():
36
+ dist.init_process_group(backend='nccl')
37
+ # Get the rank of the current process
38
+ rank = int(os.environ.get('RANK'))
39
+ world_size = int(os.environ.get('WORLD_SIZE'))
40
+ local_rank = int(os.environ.get('LOCAL_RANK'))
41
+ torch.cuda.set_device(local_rank)
42
+ return rank, world_size, local_rank
43
+
44
+ def is_dist_avail_and_initialized():
45
+ if not dist.is_available():
46
+ return False
47
+ if not dist.is_initialized():
48
+ return False
49
+ return True
50
+
51
+ def get_rank():
52
+ if not is_dist_avail_and_initialized():
53
+ return 0
54
+ return dist.get_rank()
55
+
56
+ def is_main_process():
57
+ return get_rank() == 0
58
+
59
+ def get_world_size():
60
+ if not is_dist_avail_and_initialized():
61
+ return 1
62
+ return dist.get_world_size()
63
+
64
+ def save_on_master(*args, **kwargs):
65
+ if is_main_process():
66
+ torch.save(*args, **kwargs)
67
+
68
+ def save_model(args, epoch, model, optimizer, ema_params=None, epoch_name=None):
69
+ if epoch_name is None:
70
+ epoch_name = str(epoch)
71
+
72
+ output_dir = Path(args.logdir)
73
+ checkpoint_path = output_dir / ('checkpoint-%s.pth' % epoch_name)
74
+
75
+ if ema_params is not None:
76
+ ema_state_dict = copy.deepcopy(model.state_dict())
77
+ for i, (name, _value) in enumerate(model.named_parameters()):
78
+ assert name in ema_state_dict
79
+ ema_state_dict[name] = ema_params[i]
80
+ else:
81
+ ema_state_dict = None
82
+
83
+ to_save = {
84
+ 'model': model.state_dict(),
85
+ 'optimizer': optimizer.state_dict(),
86
+ 'epoch': epoch,
87
+ 'args': args,
88
+ 'model_ema': ema_state_dict,
89
+ }
90
+
91
+ save_on_master(to_save, checkpoint_path)
92
+
93
+ def adjust_learning_rate(optimizer, epoch, args):
94
+ """Decay the learning rate with half-cycle cosine after warmup"""
95
+ if epoch < args.warmup_epochs:
96
+ lr = args.lr * epoch / args.warmup_epochs
97
+ else:
98
+ lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
99
+ (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.n_epochs - args.warmup_epochs)))
100
+ for param_group in optimizer.param_groups:
101
+ if "lr_scale" in param_group:
102
+ param_group["lr"] = lr * param_group["lr_scale"]
103
+ else:
104
+ param_group["lr"] = lr
105
+
106
+ return lr
107
+
108
+
109
+ def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
110
+ decay = []
111
+ no_decay = []
112
+ for name, param in model.named_parameters():
113
+ if not param.requires_grad:
114
+ continue # frozen weights
115
+ if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list or 'diffloss' in name:
116
+ no_decay.append(param) # no weight decay on bias, norm and diffloss
117
+ else:
118
+ decay.append(param)
119
+ return [
120
+ {'params': no_decay, 'weight_decay': 0.},
121
+ {'params': decay, 'weight_decay': weight_decay}]
122
+
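add_weight_decay splits parameters into a no-decay group (biases, norms, anything named 'diffloss') and a decay group. A minimal usage sketch with a toy model; the learning rate here is just a placeholder:

import torch
from utils.misc import add_weight_decay

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
param_groups = add_weight_decay(model, weight_decay=0.05)
optimizer = torch.optim.AdamW(param_groups, lr=1.5e-4)
print([len(g['params']) for g in param_groups])   # [3, 1]: biases + norm params vs. the linear weight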
utils/utils.py ADDED
@@ -0,0 +1,82 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ def to_tensor(x,dtype=torch.float64):
5
+ if isinstance(x, torch.Tensor):
6
+ return x.to(dtype)
7
+ elif isinstance(x, np.ndarray):
8
+ return torch.from_numpy(x.copy()).to(dtype)
9
+ else:
10
+ raise ValueError(f"Unsupported type: {type(x)}")
11
+
12
+ def to_numpy(x):
13
+ if isinstance(x, torch.Tensor):
14
+ return x.detach().cpu().numpy()
15
+ elif isinstance(x, np.ndarray):
16
+ return x
17
+ else:
18
+ raise ValueError(f"Unsupported type: {type(x)}")
19
+
20
+ def invalid_to_nans( arr, valid_mask, ndim=999 ):
21
+ if valid_mask is not None:
22
+ arr = arr.clone()
23
+ arr[~valid_mask] = float('nan')
24
+ if arr.ndim > ndim:
25
+ arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
26
+ return arr
27
+
28
+ def invalid_to_zeros( arr, valid_mask, ndim=999 ):
29
+ if valid_mask is not None:
30
+ arr = arr.clone()
31
+ arr[~valid_mask] = 0
32
+ nnz = valid_mask.view(len(valid_mask), -1).sum(1)
33
+ else:
34
+ nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image
35
+ if arr.ndim > ndim:
36
+ arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
37
+ return arr, nnz
38
+
39
+ def scenes_to_batch(scenes,repeat=None):
40
+ batch = {}
41
+ n_cams = None
42
+
43
+ if 'new_cams' in scenes:
44
+ n_cams = scenes['new_cams']['depths'].shape[1]
45
+ batch['new_cams'], n_cams = scenes_to_batch(scenes['new_cams'])
46
+ batch['input_cams'],_ = scenes_to_batch(scenes['input_cams'],repeat=n_cams)
47
+ else:
48
+ for key in scenes.keys():
49
+ shape = scenes[key].shape
50
+ if len(shape) > 3 :
51
+ n_cams = shape[1]
52
+ if repeat is not None:
53
+ # repeat the 2nd dimension by repeat times to also have the inputs repeated in the batch
54
+ repeat_dims = (1,) * len(shape) # (1,1,1,...) for all dimensions
55
+ repeat_dims = list(repeat_dims)
56
+ repeat_dims[1] = repeat
57
+ batch[key] = scenes[key].repeat(*repeat_dims)
58
+ batch[key] = batch[key].reshape(-1, *shape[2:])
59
+ else:
60
+ batch[key] = scenes[key].reshape(-1, *shape[2:])
61
+ elif key == 'dino_features':
62
+ repeat_shape = (repeat,) + (1,) * (len(shape) - 1)
63
+ batch[key] = scenes[key].repeat(*repeat_shape)
64
+ else:
65
+ batch[key] = scenes[key]
66
+ return batch, n_cams
67
+
68
+ def dict_to_scenes(input_dict,n_cams):
69
+ scenes = {}
70
+ for key in input_dict.keys():
71
+ if isinstance(input_dict[key],dict):
72
+ scenes[key] = dict_to_scenes(input_dict[key],n_cams)
73
+ else:
74
+ scenes[key] = input_dict[key].reshape(-1, n_cams, *input_dict[key].shape[1:])
75
+ return scenes
76
+
77
+ def batch_to_scenes(pred,gt,batch,n_cams):
78
+ # pred
79
+ batch = dict_to_scenes(batch,n_cams)
80
+ pred = dict_to_scenes(pred,n_cams)
81
+ gt = dict_to_scenes(gt,n_cams)
82
+ return pred, gt, batch
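scenes_to_batch flattens the (B, n_cams, ...) scene tensors into a flat batch and dict_to_scenes undoes it. A small round-trip sketch on a flat dict (the nested new_cams/input_cams case follows the same pattern):

import torch
from utils.utils import scenes_to_batch, dict_to_scenes

scenes = {'depths': torch.randn(2, 3, 16, 16), 'Ks': torch.randn(2, 3, 3, 3)}
batch, n_cams = scenes_to_batch(scenes)      # (B, n_cams, ...) -> (B * n_cams, ...)
print(batch['depths'].shape, n_cams)         # torch.Size([6, 16, 16]) 3
restored = dict_to_scenes(batch, n_cams)     # reshape back to (B, n_cams, ...)
print(restored['depths'].shape)              # torch.Size([2, 3, 16, 16])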
utils/viz.py ADDED
@@ -0,0 +1,205 @@
1
+ bb = breakpoint
2
+ import torch
3
+ import numpy as np
4
+ from utils.utils import to_tensor, to_numpy
5
+ import open3d as o3d
6
+ import rerun as rr
7
+
8
+ OPENCV2OPENGL = (1,-1,-1,1)
9
+
10
+ def pts_to_opengl(pts):
11
+ return pts*OPENCV2OPENGL[:3]
12
+
13
+ def save_pointmaps(data,path='debug',view=False,color='novelty',frustrum_scale=20):
14
+ # debug function to save points to a ply file
15
+ import open3d as o3d
16
+ pointmaps = data['pointmaps']
17
+ B = pointmaps.shape[0]
18
+ W, H = pointmaps.shape[-3:-1]
19
+ n_cams = data['c2ws'].shape[1]
20
+ geometries = []
21
+ for b in range(B):
22
+ geometry_b = []
23
+ points = torch.cat([p.flatten(start_dim=0,end_dim=1) for p in pointmaps[b]],dim=0)
24
+ if view:
25
+ pcd = o3d.geometry.PointCloud()
26
+ pcd.points = o3d.utility.Vector3dVector(to_numpy(points))
27
+ if color == 'novelty':
28
+ colors = torch.ones_like(points)
29
+ pts_p_cam = W*H
30
+ # make all novel points red
31
+ colors[pts_p_cam:,1:]*=0.1
32
+
33
+ # make all points from the first camera green (the red and blue channels are dimmed below)
34
+ colors[:pts_p_cam,0]*=0.1
35
+ colors[:pts_p_cam,2]*=0.1
36
+ colors*=255.0
37
+
38
+ else:
39
+ colors = torch.cat([p.flatten(start_dim=0,end_dim=1) for p in data['imgs'][b]],dim=0)
40
+ pcd.colors = o3d.utility.Vector3dVector(to_numpy(colors)/255.0)
41
+ geometry_b.append(pcd)
42
+ origin = o3d.geometry.TriangleMesh.create_coordinate_frame(
43
+ size=10, origin=[0,0,0])
44
+ geometry_b.append(origin)
45
+ for i in range(n_cams):
46
+ K = data['Ks'][b,i].cpu().numpy()
47
+ K = o3d.camera.PinholeCameraIntrinsic(W,H,K)
48
+ P = data['c2ws'][b,i].cpu().numpy()
49
+ cam_frame = o3d.geometry.LineSet.create_camera_visualization(intrinsic=K,extrinsic=P,scale=frustrum_scale)
50
+ geometry_b.append(cam_frame)
51
+ o3d.visualization.draw_geometries(geometry_b)
52
+
53
+ # add point at the origin
54
+ o3d.io.write_point_cloud(f"{path}_{b}.ply", pcd)
55
+ breakpoint()
56
+ geometries.append(geometry_b)
57
+ return geometries
58
+
59
+ def just_load_viz(pred_dict,gt_dict,batch,name='just_load_viz',addr='localhost:9000',fused_meshes=None,n_points=None):
60
+ rr.init(name)
61
+ rr.connect(addr)
62
+ rr.set_time_seconds("stable_time", 0)
63
+
64
+ context_views = batch['input_cams']['pointmaps']
65
+ context_rgbs = batch['input_cams']['imgs']
66
+ gt_pred_views = gt_dict['pointmaps']
67
+ pred_views = pred_dict['pointmaps']
68
+
69
+ # FIX this weird shape
70
+ pred_masks = batch['new_cams']['valid_masks']
71
+ context_masks = batch['input_cams']['valid_masks']
72
+
73
+ B = batch['new_cams']['pointmaps'].shape[0]
74
+ W,H = context_views.shape[-3:-1]
75
+ n_pred_cams = pred_views.shape[1]
76
+
77
+ for b in range(B):
78
+ rr.set_time_seconds("stable_time", b)
79
+ # Set world transform to identity (normal origin)
80
+ rr.log("world", rr.Transform3D(translation=[0, 0, 0], mat3x3=np.eye(3)))
81
+ ## show context views
82
+ context_rgb = to_numpy(context_rgbs[b])
83
+
84
+ for i in range(n_pred_cams):
85
+ if 'conf_pointmaps' in pred_dict:
86
+ conf_pts = pred_dict['conf_pointmaps'][b,i]
87
+
88
+ #print(f"view {i} mean conf: {mean_conf}, std conf: {std_conf}")
89
+ conf_pts = (conf_pts - conf_pts.min())/(conf_pts.max() - conf_pts.min())
90
+ conf_pts = to_numpy(conf_pts)
91
+ rr.log(f"view_{i}/pred_conf", rr.Image(conf_pts))
92
+ if pred_masks[b,i].sum() == 0:
93
+ continue
94
+ if gt_pred_views is not None:
95
+ gt_pred_pts = gt_pred_views[b,i][pred_masks[b,i]]
96
+ gt_pred_pts = to_numpy(gt_pred_pts)
97
+ else:
98
+ gt_pred_pts = None
99
+
100
+ # red is color for gt points
101
+ if gt_pred_pts is not None:
102
+ color = np.array([1,0,0])
103
+ colors = np.ones_like(gt_pred_pts)
104
+ colors[:,0] = color[0]
105
+ colors[:,1] = color[1]
106
+ colors[:,2] = color[2]
107
+ rr.log(
108
+ f"world/new_views_gt/view_{i}", rr.Points3D(gt_pred_pts,colors=colors)
109
+ )
110
+ # green is color for pred points
111
+ pred_pts = pred_views[b,i][pred_masks[b,i]]
112
+ pred_pts = to_numpy(pred_pts)
113
+
114
+ depth = pred_views[b,i][:,:,2]
115
+ depth -= depth[pred_masks[b,i]].min()
116
+ depth[~pred_masks[b,i]] = 0
117
+ depth /= depth.max()
118
+ depth = to_numpy(depth)
119
+ rr.log(f"world/new_views_pred/view_{i}/image", rr.Image(depth))
120
+
121
+ if 'classifier' in pred_dict:
122
+ classifier = (pred_dict['classifier'][b,i] > 0.0).float() # threshold raw logits at 0, i.e. a sigmoid probability of 0.5
123
+ classifier = to_numpy(classifier)
124
+ rr.log(f"view_{i}/pred_mask", rr.Image(classifier))
125
+
126
+ color = np.array([0,1,0])
127
+ colors = np.ones_like(pred_pts)
128
+ colors[:,0] = color[0]
129
+ colors[:,1] = color[1]
130
+ colors[:,2] = color[2]
131
+ if n_points is None:
132
+ rr.log(
133
+ f"world/new_views_pred/view_{i}/pred_points", rr.Points3D(pred_pts,colors=colors)
134
+ )
135
+ else:
136
+ # randomly sample n_points from pred_pts
137
+ n_points = min(n_points, pred_pts.shape[0])
138
+ inds = np.random.choice(pred_pts.shape[0], n_points, replace=False)
139
+ rr.log(
140
+ f"world/new_views_pred/view_{i}/pred_points", rr.Points3D(pred_pts[inds],colors=colors[inds])
141
+ )
142
+
143
+ K = batch['new_cams']['Ks'][b,i].cpu().numpy()
144
+ P = batch['new_cams']['c2ws'][b,i].cpu().numpy()
145
+ P = np.linalg.inv(P)
146
+ rr.log(f"world/new_views_pred/view_{i}", rr.Transform3D(translation=P[:3,3], mat3x3=P[:3,:3], from_parent=True))
147
+
148
+ rr.log(f"world/new_views_gt/view_{i}", rr.Transform3D(translation=P[:3,3], mat3x3=P[:3,:3], from_parent=True))
149
+
150
+ if 'classifier' in pred_dict:
151
+ classifier = gt_dict['valid_masks'][b,i].float()
152
+ classifier = to_numpy(classifier)
153
+ rr.log(f"view_{i}/gt_mask", rr.Image(classifier))
154
+
155
+ rr.log(
156
+ f"world/new_views_pred/view_{i}/image",
157
+ rr.Pinhole(
158
+ resolution=[H, W],
159
+ focal_length=[K[0,0], K[1,1]],
160
+ principal_point=[K[0,2], K[1,2]],
161
+ ),
162
+ )
163
+
164
+ rr.log(f"world/new_views_pred/view_{i}/image", rr.Image(to_numpy(pred_masks[b,i].float())))
165
+ n_input_cams = context_masks.shape[1]
166
+
167
+ for i in range(n_input_cams):
168
+ context_pts = context_views[b][i][context_masks[b][i]]
169
+ context_pts = to_numpy(context_pts)
170
+ context_pts_rgb = context_rgbs[b][i][context_masks[b][i]]
171
+ context_pts_rgb = to_numpy(context_pts_rgb)
172
+
173
+ # depth imgs
174
+ #context_depths = batch['input_cams']['depths'][b][i]
175
+ #context_depths = (context_depths / context_depths.max() * 255.0).clamp(0,255)
176
+ #context_depths = to_numpy(context_depths).astype(np.uint8)
177
+ rr.log(
178
+ f"world/context_views/view_{i}_points", rr.Points3D(context_pts,colors=(context_pts_rgb/255.0))
179
+ )
180
+
181
+ K = batch['input_cams']['Ks'][b,i].cpu().numpy()
182
+ P = batch['input_cams']['c2ws'][b,i].cpu().numpy()
183
+ P = np.linalg.inv(P)
184
+ rr.log(f"world/context_views/view_{i}", rr.Transform3D(translation=P[:3,3], mat3x3=P[:3,:3], from_parent=True))
185
+
186
+ rr.log(
187
+ f"world/context_views/view_{i}/image",
188
+ rr.Pinhole(
189
+ resolution=[H, W],
190
+ focal_length=[K[0,0], K[1,1]],
191
+ principal_point=[K[0,2], K[1,2]],
192
+ ),
193
+ )
194
+ context_rgb_i = context_rgb[i]
195
+ rr.log(
196
+ f"world/context_views/view_{i}/image", rr.Image(context_rgb_i)
197
+ )
198
+
199
+ rr.log(
200
+ f"world/context_camera_{i}/mask", rr.Image(to_numpy(context_masks[b,i].float()))
201
+ )
202
+ if fused_meshes is not None:
203
+ rr.log(f"world/fused_mesh", rr.Mesh3D(vertex_positions=fused_meshes[b]['verts'], vertex_normals=fused_meshes[b]['norms'], vertex_colors=fused_meshes[b]['colors'], triangle_indices=fused_meshes[b]['faces']))
204
+
205
+
xps/train_rayst3r.py ADDED
@@ -0,0 +1,127 @@
1
+ import sys
2
+ import socket
3
+ import os
4
+ # Add the current working directory to the Python path
5
+ current_dir = os.getcwd()
6
+ sys.path.append(current_dir)
7
+ from xps.util import *
8
+
9
+ root_log_dir = "logs"
10
+ n_views = 2
11
+ dataset_size = -1
12
+
13
+ imshape_input = (480,640)
14
+ imshape_output = (480,640)
15
+ render_size = (480,640)
16
+
17
+ preload_train = False
18
+ data_dirs = ["/home/jovyan/shared/bduister/data/processed/","/home/jovyan/shared/bduister/data-2/processed/"]
19
+ dino_features = [4,11,17,23]
20
+ datasets = ['fp_gso','octmae']
21
+ prefetch_dino = False
22
+ normalize_mode = 'median'
23
+ #start_from = "checkpoints/gso_conf.pth"
24
+ start_from = None
25
+
26
+ noise_std = 0.005
27
+ view_select_mode = "new_zoom"
28
+ rendered_views_mode = "always"
29
+ dataset_train = f"GenericLoader(size={dataset_size},seed=747,dir={repr(data_dirs)},split='train',datasets={datasets},mode='fast',prefetch_dino={prefetch_dino}," \
30
+ +f"dino_features={dino_features},view_select_mode='{view_select_mode}',noise_std={noise_std},rendered_views_mode='{rendered_views_mode}')"
31
+ dataset_test = f"GenericLoader(size=1000,seed=787,dir={repr(data_dirs)},split='test',datasets={datasets},mode='fast',prefetch_dino={prefetch_dino}," \
32
+ +f"dino_features={dino_features},view_select_mode='{view_select_mode}',noise_std={noise_std},rendered_views_mode='{rendered_views_mode}')"
33
+ dataset_just_load = f"GenericLoader(size=1000,seed=787,dir={repr(data_dirs)},split='test',datasets={datasets},mode='fast',prefetch_dino={prefetch_dino}," \
34
+ +f"dino_features={dino_features},view_select_mode='{view_select_mode}',noise_std={noise_std},rendered_views_mode='{rendered_views_mode}')"
35
+
36
+ augmentor = "Augmentor()"
37
+
38
+ patch_size = 16
39
+ save_every = 1
40
+
41
+ vit="base"
42
+ if vit == "debug":
43
+ enc_dim = 128
44
+ dec_dim = 128
45
+ n_heads = 4
46
+ enc_depth = 4
47
+ dec_depth = 4
48
+ head_n_layers = 1
49
+ head_dim = 128
50
+ lr = 3e-4
51
+ batch_size = 20
52
+ blr = 1.5e-4
53
+ elif vit == "debug_2":
54
+ enc_dim = 512
55
+ dec_dim = 512
56
+ n_heads = 4
57
+ enc_depth = 4
58
+ dec_depth = 10
59
+ head_n_layers = 1
60
+ head_dim = 128
61
+ blr = 1.5e-4
62
+ batch_size = 18
63
+ elif vit == "small":
64
+ enc_dim = 384
65
+ dec_dim = 384
66
+ n_heads = 6
67
+ enc_depth = 12
68
+ dec_depth = 12
69
+ head_n_layers = 1
70
+ head_dim = 128
71
+ batch_size = 6
72
+ blr = 1.5e-4
73
+ elif vit == "base":
74
+ enc_dim = 768
75
+ dec_dim = 768
76
+ n_heads = 12
77
+ enc_depth = 4
78
+ dec_depth = 12
79
+ head_n_layers = 1
80
+ head_dim = 128
81
+ batch_size = 10
82
+ blr = 1.5e-4
83
+
84
+ lambda_classifier = 0.1
85
+ for skip_conf_points in [False]:
86
+ skip_conf_mask = True
87
+ model = f"RayQuery(ray_enc=RayEncoder(dim={enc_dim},num_heads={n_heads},depth={enc_depth},img_size={render_size},patch_size={patch_size})," + \
88
+ f"pointmap_enc=PointmapEncoder(dim={enc_dim},num_heads={n_heads},depth={enc_depth},img_size={render_size},patch_size={patch_size})," + \
89
+ f"dino_layers={dino_features}," + \
90
+ f"pts_head_type='dpt_depth'," + \
91
+ f"classifier_head_type='dpt_mask'," + \
92
+ f"decoder_dim={dec_dim},decoder_depth={dec_depth},decoder_num_heads={n_heads},imshape={render_size}," + \
93
+ f"criterion=DepthCompletion(ConfLoss(L21,skip_conf={skip_conf_points}),ConfLoss(ClassifierLoss(BCELoss()),skip_conf={skip_conf_mask}),lambda_classifier={lambda_classifier}),return_all_blocks=True)"
94
+
95
+ key = f"conf_points_{not skip_conf_points}"
96
+ key = gen_key(key)
97
+ logdir = os.path.join(root_log_dir,key)
98
+ resume=logdir
99
+ wandb_run_name=key
100
+ os.makedirs(logdir,exist_ok=True)
101
+
102
+ n_epochs = 20
103
+ eval_every = 1
104
+ max_norm = -1
105
+ OMP_NUM_THREADS=16
106
+ warmup_epochs = 1
107
+
108
+ executable = f"OMP_NUM_THREADS={OMP_NUM_THREADS} torchrun --nnodes 1 --nproc_per_node $(python -c 'import torch; print(torch.cuda.device_count())') --master_port $((RANDOM%500+29000)) main.py"
109
+ #executable = f"python main.py"
110
+ if '--just_load' in sys.argv:
111
+ batch_size = 5
112
+ command = f"{executable} --{dataset_train=} --{dataset_test=} --{dataset_just_load=} --{logdir=} --{resume=} --{model=} --{batch_size=} --{normalize_mode=} --{augmentor=}"
113
+ else:
114
+ command = f"{executable} --{dataset_train=} --{dataset_test=} --{logdir=} --{n_epochs=} --{resume=} --{normalize_mode=} --{augmentor=} --{warmup_epochs=}"
115
+ command += f" --{model=} --{eval_every=} --{batch_size=} --{save_every=} --{max_norm=}"
116
+ command += f" --{blr=}"
117
+ if start_from is not None:
118
+ command += f" --{start_from=}"
119
+ if '--no_wandb' not in sys.argv:
120
+ command += f" --wandb_project=3dcomplete " + \
121
+ f"--{wandb_run_name=}"
122
+
123
+ if len(sys.argv) > 1:
124
+ for arg in sys.argv[1:]:
125
+ if '--no_wandb' not in arg:
126
+ command += f" {arg}"
127
+ print(command)
xps/util.py ADDED
@@ -0,0 +1,12 @@
1
+ import inspect
2
+ import os
3
+
4
+ def gen_key(raw_key):
5
+ # concat the raw_key with the file name that this function is called from
6
+ current_frame = inspect.currentframe()
7
+ # Get the caller's frame (the frame that called this function)
8
+ caller_frame = current_frame.f_back
9
+ # Extract the filename from the caller's frame
10
+ caller_file = caller_frame.f_code.co_filename
11
+ caller_file = os.path.basename(caller_file).replace(".py","")
12
+ return f"{caller_file}_{raw_key}"
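gen_key prefixes the raw key with the basename of the calling script. A usage sketch: if the snippet below lived in xps/train_rayst3r.py (as in the experiment file above), it would print train_rayst3r_conf_points_True:

from xps.util import gen_key

key = gen_key("conf_points_True")
print(key)   # "<caller_filename>_conf_points_True", e.g. "train_rayst3r_conf_points_True"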