Commit
·
70d1188
0
Parent(s):
init
Browse files- .gitignore +13 -0
- LICENSE +109 -0
- app.py +199 -0
- datasets/generic_loader.py +168 -0
- engine.py +139 -0
- eval_wrapper/eval.py +425 -0
- eval_wrapper/eval_utils.py +125 -0
- eval_wrapper/sample_poses.py +100 -0
- example_scene/cam2world.pt +0 -0
- example_scene/intrinsics.pt +0 -0
- extensions/curope/__init__.py +4 -0
- extensions/curope/curope.cpp +69 -0
- extensions/curope/curope.egg-info/PKG-INFO +10 -0
- extensions/curope/curope.egg-info/SOURCES.txt +9 -0
- extensions/curope/curope.egg-info/dependency_links.txt +1 -0
- extensions/curope/curope.egg-info/top_level.txt +1 -0
- extensions/curope/curope2d.py +43 -0
- extensions/curope/kernels.cu +108 -0
- extensions/curope/setup.py +34 -0
- input/cam2world.pt +0 -0
- input/intrinsics.pt +0 -0
- main.py +198 -0
- models/blocks.py +235 -0
- models/heads/__init__.py +26 -0
- models/heads/dpt_head.py +582 -0
- models/heads/linear_head.py +42 -0
- models/heads/postprocess.py +80 -0
- models/losses.py +257 -0
- models/pos_embed.py +156 -0
- models/rayquery.py +227 -0
- readme.md +112 -0
- requirements.txt +18 -0
- utils/augmentations.py +184 -0
- utils/batch_prep.py +141 -0
- utils/collate.py +7 -0
- utils/eval.py +20 -0
- utils/fusion.py +476 -0
- utils/geometry.py +195 -0
- utils/misc.py +122 -0
- utils/utils.py +82 -0
- utils/viz.py +205 -0
- xps/train_rayst3r.py +127 -0
- xps/util.py +12 -0
.gitignore
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.out
|
2 |
+
slurm/
|
3 |
+
*.pyc
|
4 |
+
*.png
|
5 |
+
!assets/*.png
|
6 |
+
*.mtl
|
7 |
+
*.obj
|
8 |
+
*.ply
|
9 |
+
*.pth
|
10 |
+
**/build/**
|
11 |
+
*.so
|
12 |
+
wandb**
|
13 |
+
logs
|
LICENSE
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
RaySt3R
|
2 |
+
SOFTWARE LICENSE AGREEMENT
|
3 |
+
ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH
|
4 |
+
USE ONLY
|
5 |
+
BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO
|
6 |
+
THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH
|
7 |
+
THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE.
|
8 |
+
|
9 |
+
This is a license agreement ("Agreement") between your academic institution or non-
|
10 |
+
profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie
|
11 |
+
|
12 |
+
Mellon University (called "Licensor" in this Agreement). All rights not specifically
|
13 |
+
granted to you in this Agreement are reserved for Licensor.
|
14 |
+
RESERVATION OF OWNERSHIP AND GRANT OF LICENSE:
|
15 |
+
Licensor retains exclusive ownership of any copy of the Software (as defined below)
|
16 |
+
licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive,
|
17 |
+
non-transferable license to use the Software for noncommercial research purposes,
|
18 |
+
without the right to sublicense, pursuant to the terms and conditions of this Agreement.
|
19 |
+
As used in this Agreement, the term "Software" means (i) the actual copy of all or any
|
20 |
+
portion of code for program routines made accessible to Licensee by Licensor pursuant to
|
21 |
+
this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder
|
22 |
+
or subsequently supplied by Licensor, including all or any file structures, programming
|
23 |
+
instructions, user interfaces and screen formats and sequences as well as any and all
|
24 |
+
documentation and instructions related to it, and (ii) all or any derivatives and/or
|
25 |
+
modifications created or made by You to any of the items specified in (i).
|
26 |
+
CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to
|
27 |
+
Licensor, and as such, Licensee agrees to receive all such materials in confidence and use
|
28 |
+
the Software only in accordance with the terms of this Agreement. Licensee agrees to
|
29 |
+
use reasonable effort to protect the Software from unauthorized use, reproduction,
|
30 |
+
distribution, or publication.
|
31 |
+
COPYRIGHT: The Software is owned by Licensor and is protected by United
|
32 |
+
States copyright laws and applicable international treaties and/or conventions.
|
33 |
+
PERMITTED USES: The Software may be used for your own noncommercial internal
|
34 |
+
research purposes. You understand and agree that Licensor is not obligated to implement
|
35 |
+
any suggestions and/or feedback you might provide regarding the Software, but to the
|
36 |
+
extent Licensor does so, you are not entitled to any compensation related thereto.
|
37 |
+
DERIVATIVES: You may create derivatives of or make modifications to the Software,
|
38 |
+
however, You agree that all and any such derivatives and modifications will be owned by
|
39 |
+
Licensor and become a part of the Software licensed to You under this Agreement. You
|
40 |
+
may only use such derivatives and modifications for your own noncommercial internal
|
41 |
+
research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement. You must provide to Licensor one copy
|
42 |
+
of all such derivatives and modifications in a recognized electronic format by way of
|
43 |
+
electronic mail sent to Bardienus Pieter Duisterhof at [email protected]
|
44 |
+
within thirty (30) days of the publication date of any publication that relates to any such
|
45 |
+
derivatives or modifications. You understand that Licensor is not obligated to distribute
|
46 |
+
or otherwise make available any derivatives or modifications provided by You.
|
47 |
+
BACKUPS: If Licensee is an organization, it may make that number of copies of the
|
48 |
+
Software necessary for internal noncommercial use at a single site within its organization
|
49 |
+
provided that all information appearing in or on the original labels, including the
|
50 |
+
copyright and trademark notices are copied onto the labels of the copies.
|
51 |
+
USES NOT PERMITTED: You may not distribute, copy or use the Software except as
|
52 |
+
explicitly permitted herein. Licensee has not been granted any trademark license as part
|
53 |
+
of this Agreement and may not use the name or mark "RaySt3R" "Carnegie Mellon" or any renditions thereof without the prior written
|
54 |
+
permission of Licensor.
|
55 |
+
You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part,
|
56 |
+
or provide third parties access to prior or present versions (or any parts thereof) of the
|
57 |
+
Software.
|
58 |
+
ASSIGNMENT: You may not assign this Agreement or your rights hereunder without
|
59 |
+
the prior written consent of Licensor. Any attempted assignment without such consent
|
60 |
+
shall be null and void.
|
61 |
+
TERM: The term of the license granted by this Agreement is from Licensee's acceptance
|
62 |
+
of this Agreement by clicking "I Agree" below or by using the Software until terminated
|
63 |
+
as provided below.
|
64 |
+
The Agreement automatically terminates without notice if you fail to comply with any
|
65 |
+
provision of this Agreement. Licensee may terminate this Agreement by ceasing using
|
66 |
+
the Software. Upon any termination of this Agreement, Licensee will delete any and all
|
67 |
+
copies of the Software. You agree that all provisions which operate to protect the
|
68 |
+
proprietary rights of Licensor shall remain in force should breach occur and that the
|
69 |
+
obligation of confidentiality described in this Agreement is binding in perpetuity and, as
|
70 |
+
such, survives the term of the Agreement.
|
71 |
+
FEE: Provided Licensee abides completely by the terms and conditions of this
|
72 |
+
Agreement, there is no fee due to Licensor for Licensee's use of the Software in
|
73 |
+
accordance with this Agreement.
|
74 |
+
DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS"
|
75 |
+
WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF
|
76 |
+
PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR
|
77 |
+
USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK
|
78 |
+
|
79 |
+
RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND
|
80 |
+
RELATED MATERIALS.
|
81 |
+
SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is
|
82 |
+
provided as part of this Agreement.
|
83 |
+
EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent
|
84 |
+
permitted under applicable law, Licensor shall not be liable for direct, indirect, special,
|
85 |
+
incidental, or consequential damages or lost profits related to Licensee's use of and/or
|
86 |
+
inability to use the Software, even if Licensor is advised of the possibility of such
|
87 |
+
damage.
|
88 |
+
EXPORT REGULATION: You agree to comply with any and all applicable U.S. export
|
89 |
+
control laws, regulations, and/or other laws related to the embargoes and sanction
|
90 |
+
programs administered by the U.S. Office of Foreign Assets Control. You may not export
|
91 |
+
or re-export the technology with individuals or companies on the U.S. Department of
|
92 |
+
Commerce, Department of State or Department of Treasury denied party lists
|
93 |
+
https://www.trade.gov/consolidated-screening-list . You represent and warrant that
|
94 |
+
Licensee is not an individual or company listed on such denied party lists.
|
95 |
+
SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid,
|
96 |
+
illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the
|
97 |
+
validity, legality and enforceability of the remaining provisions shall not in any way be
|
98 |
+
affected or impaired thereby.
|
99 |
+
NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or
|
100 |
+
remedy under this Agreement shall be construed as a waiver of any future or other
|
101 |
+
exercise of such right or remedy by Licensor.
|
102 |
+
GOVERNING LAW: This Agreement shall be construed and enforced in accordance
|
103 |
+
with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws
|
104 |
+
principles. You consent to the personal jurisdiction of the courts of this County and
|
105 |
+
waive their rights to venue outside of Allegheny County, Pennsylvania.
|
106 |
+
ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole
|
107 |
+
and entire agreement between Licensee and Licensor as to the matter set forth herein and
|
108 |
+
supersedes any previous agreements, understandings, and arrangements between the
|
109 |
+
parties relating hereto.
|
app.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import gradio as gr
|
3 |
+
import torch
|
4 |
+
import rembg
|
5 |
+
import trimesh
|
6 |
+
from moge.model.v1 import MoGeModel
|
7 |
+
from utils.geometry import compute_pointmap
|
8 |
+
import os, shutil
|
9 |
+
import cv2
|
10 |
+
from huggingface_hub import hf_hub_download
|
11 |
+
from PIL import Image
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
from eval_wrapper.eval import EvalWrapper, eval_scene
|
14 |
+
from torchvision import transforms
|
15 |
+
|
16 |
+
outdir = "/tmp/rayst3r"
|
17 |
+
|
18 |
+
# loading all necessary models
|
19 |
+
print("Loading DINOv2 model")
|
20 |
+
dino_model = torch.hub.load('facebookresearch/dinov2', "dinov2_vitl14_reg")
|
21 |
+
dino_model.eval()
|
22 |
+
dino_model.to("cuda")
|
23 |
+
|
24 |
+
print("Loading MoGe model")
|
25 |
+
device = torch.device("cuda")
|
26 |
+
# Load the model from huggingface hub (or load from local).
|
27 |
+
moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device)
|
28 |
+
|
29 |
+
print("Loading RaySt3R model")
|
30 |
+
rayst3r_checkpoint = hf_hub_download("bartduis/rayst3r", "rayst3r.pth")
|
31 |
+
rayst3r_model = EvalWrapper(rayst3r_checkpoint)
|
32 |
+
|
33 |
+
def depth2uint16(depth):
|
34 |
+
return depth * torch.iinfo(torch.uint16).max / 10.0 # threshold is in m, convert to uint16 value
|
35 |
+
|
36 |
+
def save_tensor_as_png(tensor: torch.Tensor, path: str, dtype: torch.dtype | None = None):
|
37 |
+
if dtype is None:
|
38 |
+
dtype = tensor.dtype
|
39 |
+
Image.fromarray(tensor.to(dtype).cpu().numpy()).save(path)
|
40 |
+
|
41 |
+
def colorize_points_with_turbo_all_dims(points, method='norm',cmap='turbo'):
|
42 |
+
"""
|
43 |
+
Assigns colors to 3D points using the 'turbo' colormap based on a scalar computed from all 3 dimensions.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
points (np.ndarray): (N, 3) array of 3D points.
|
47 |
+
method (str): Method for reducing 3D point to scalar. Options: 'norm', 'pca'.
|
48 |
+
|
49 |
+
Returns:
|
50 |
+
np.ndarray: (N, 3) RGB colors in [0, 1].
|
51 |
+
"""
|
52 |
+
assert points.shape[1] == 3, "Input must be of shape (N, 3)"
|
53 |
+
|
54 |
+
if method == 'norm':
|
55 |
+
scalar = np.linalg.norm(points, axis=1)
|
56 |
+
elif method == 'pca':
|
57 |
+
# Project onto first principal component
|
58 |
+
mean = points.mean(axis=0)
|
59 |
+
centered = points - mean
|
60 |
+
u, s, vh = np.linalg.svd(centered, full_matrices=False)
|
61 |
+
scalar = centered @ vh[0] # Project onto first principal axis
|
62 |
+
else:
|
63 |
+
raise ValueError(f"Unknown method '{method}'")
|
64 |
+
|
65 |
+
# Normalize scalar to [0, 1]
|
66 |
+
scalar_min, scalar_max = scalar.min(), scalar.max()
|
67 |
+
normalized = (scalar - scalar_min) / (scalar_max - scalar_min + 1e-8)
|
68 |
+
|
69 |
+
# Apply turbo colormap
|
70 |
+
cmap = plt.colormaps.get_cmap(cmap)
|
71 |
+
colors = cmap(normalized)[:, :3] # Drop alpha
|
72 |
+
|
73 |
+
return colors
|
74 |
+
|
75 |
+
def prep_for_rayst3r(img,depth_dict,mask):
|
76 |
+
H, W = img.shape[:2]
|
77 |
+
intrinsics = depth_dict["intrinsics"].detach().cpu()
|
78 |
+
intrinsics[0] *= W
|
79 |
+
intrinsics[1] *= H
|
80 |
+
|
81 |
+
input_dir = os.path.join(outdir, "input")
|
82 |
+
if os.path.exists(input_dir):
|
83 |
+
shutil.rmtree(input_dir)
|
84 |
+
os.makedirs(input_dir, exist_ok=True)
|
85 |
+
# save intrinsics
|
86 |
+
torch.save(intrinsics, os.path.join(input_dir, "intrinsics.pt"))
|
87 |
+
|
88 |
+
# save depth
|
89 |
+
depth = depth_dict["depth"].cpu()
|
90 |
+
depth = depth2uint16(depth)
|
91 |
+
save_tensor_as_png(depth, os.path.join(input_dir, "depth.png"),dtype=torch.uint16)
|
92 |
+
|
93 |
+
# save mask as bool
|
94 |
+
save_tensor_as_png(torch.from_numpy(mask).bool(), os.path.join(input_dir, "mask.png"),dtype=torch.bool)
|
95 |
+
# save image
|
96 |
+
save_tensor_as_png(torch.from_numpy(img), os.path.join(input_dir, "rgb.png"))
|
97 |
+
|
98 |
+
def rayst3r_to_glb(img,depth_dict,mask,max_total_points=10e6,rotated=False):
|
99 |
+
prep_for_rayst3r(img,depth_dict,mask)
|
100 |
+
rayst3r_points = eval_scene(rayst3r_model,os.path.join(outdir, "input"),do_filter_all_masks=True,dino_model=dino_model).cpu()
|
101 |
+
|
102 |
+
# subsample points
|
103 |
+
n_points = min(max_total_points,rayst3r_points.shape[0])
|
104 |
+
rayst3r_points = rayst3r_points[torch.randperm(rayst3r_points.shape[0])[:n_points]].numpy()
|
105 |
+
|
106 |
+
rayst3r_points[:,1] = -rayst3r_points[:,1]
|
107 |
+
rayst3r_points[:,2] = -rayst3r_points[:,2]
|
108 |
+
|
109 |
+
# make all points red
|
110 |
+
colors = colorize_points_with_turbo_all_dims(rayst3r_points)
|
111 |
+
|
112 |
+
# load the input glb
|
113 |
+
scene = trimesh.Scene()
|
114 |
+
pct = trimesh.PointCloud(rayst3r_points, colors=colors, radius=0.01)
|
115 |
+
scene.add_geometry(pct)
|
116 |
+
|
117 |
+
outfile = os.path.join(outdir, "rayst3r.glb")
|
118 |
+
scene.export(outfile)
|
119 |
+
return outfile
|
120 |
+
|
121 |
+
|
122 |
+
def input_to_glb(outdir,img,depth_dict,mask,rotated=False):
|
123 |
+
H, W = img.shape[:2]
|
124 |
+
intrinsics = depth_dict["intrinsics"].cpu().numpy()
|
125 |
+
intrinsics[0] *= W
|
126 |
+
intrinsics[1] *= H
|
127 |
+
|
128 |
+
depth = depth_dict["depth"].cpu().numpy()
|
129 |
+
cam2world = np.eye(4)
|
130 |
+
points_world = compute_pointmap(depth, cam2world, intrinsics)
|
131 |
+
|
132 |
+
scene = trimesh.Scene()
|
133 |
+
pts = np.concatenate([p[m] for p,m in zip(points_world,mask)])
|
134 |
+
col = np.concatenate([c[m] for c,m in zip(img,mask)])
|
135 |
+
|
136 |
+
pts = pts.reshape(-1,3)
|
137 |
+
pts[:,1] = -pts[:,1]
|
138 |
+
pts[:,2] = -pts[:,2]
|
139 |
+
|
140 |
+
|
141 |
+
pct = trimesh.PointCloud(pts, colors=col.reshape(-1,3))
|
142 |
+
scene.add_geometry(pct)
|
143 |
+
|
144 |
+
outfile = os.path.join(outdir, "input.glb")
|
145 |
+
scene.export(outfile)
|
146 |
+
return outfile
|
147 |
+
|
148 |
+
def depth_moge(input_img):
|
149 |
+
input_img_torch = torch.tensor(input_img / 255, dtype=torch.float32, device=device).permute(2, 0, 1)
|
150 |
+
output = moge_model.infer(input_img_torch)
|
151 |
+
return output
|
152 |
+
|
153 |
+
def mask_rembg(input_img):
|
154 |
+
#masked_img = rembg.remove(input_img,)
|
155 |
+
output_img = rembg.remove(input_img, alpha_matting=False, post_process_mask=True)
|
156 |
+
|
157 |
+
# Convert to NumPy array
|
158 |
+
output_np = np.array(output_img)
|
159 |
+
alpha = output_np[..., 3]
|
160 |
+
|
161 |
+
# Step 2: Erode the alpha mask to shrink object slightly
|
162 |
+
kernel = np.ones((3, 3), np.uint8) # Adjust size for aggressiveness
|
163 |
+
eroded_alpha = cv2.erode(alpha, kernel, iterations=1)
|
164 |
+
# Step 3: Replace alpha channel
|
165 |
+
output_np[..., 3] = eroded_alpha
|
166 |
+
|
167 |
+
mask = output_np[:,:,-1] >= 128
|
168 |
+
rgb = output_np[:,:,:3]
|
169 |
+
return mask, rgb
|
170 |
+
|
171 |
+
def process_image(input_img):
|
172 |
+
# resize the input image
|
173 |
+
rotated = False
|
174 |
+
#if input_img.shape[0] > input_img.shape[1]:
|
175 |
+
#input_img = cv2.rotate(input_img, cv2.ROTATE_90_COUNTERCLOCKWISE)
|
176 |
+
#rotated = True
|
177 |
+
input_img = cv2.resize(input_img, (640, 480))
|
178 |
+
mask, rgb = mask_rembg(input_img)
|
179 |
+
depth_dict = depth_moge(input_img)
|
180 |
+
|
181 |
+
if os.path.exists(outdir):
|
182 |
+
shutil.rmtree(outdir)
|
183 |
+
os.makedirs(outdir)
|
184 |
+
|
185 |
+
input_glb = input_to_glb(outdir,input_img,depth_dict,mask,rotated=rotated)
|
186 |
+
|
187 |
+
# visualize the input points in 3D in gradio
|
188 |
+
inference_glb = rayst3r_to_glb(input_img,depth_dict,mask,rotated=rotated)
|
189 |
+
|
190 |
+
return input_glb, inference_glb
|
191 |
+
|
192 |
+
demo = gr.Interface(
|
193 |
+
process_image,
|
194 |
+
gr.Image(),
|
195 |
+
[gr.Model3D(label="Input"), gr.Model3D(label="RaySt3R",)]
|
196 |
+
)
|
197 |
+
|
198 |
+
if __name__ == "__main__":
|
199 |
+
demo.launch()
|
datasets/generic_loader.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
bb = breakpoint
|
2 |
+
import torch
|
3 |
+
import trimesh
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import numpy as np
|
6 |
+
import os
|
7 |
+
from pathlib import Path
|
8 |
+
import pickle
|
9 |
+
import tqdm
|
10 |
+
import json
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
class GenericLoader(torch.utils.data.Dataset):
|
14 |
+
def __init__(self,dir="octmae_data/tiny_train/train_processed",seed=747,size=10,datasets=["fp_objaverse"],split="train",dtype=torch.float32,mode="slow",
|
15 |
+
prefetch_dino=False,dino_features=[23],view_select_mode="new_zoom",noise_std=0.0,rendered_views_mode="None",**kwargs):
|
16 |
+
super().__init__(**kwargs)
|
17 |
+
self.dir = dir
|
18 |
+
self.rng = np.random.default_rng(seed)
|
19 |
+
self.size = size
|
20 |
+
self.datasets = datasets
|
21 |
+
self.split = split
|
22 |
+
self.dtype = dtype
|
23 |
+
self.mode = mode
|
24 |
+
self.prefetch_dino = prefetch_dino
|
25 |
+
self.view_select_mode = view_select_mode
|
26 |
+
self.noise_std = noise_std * torch.iinfo(torch.uint16).max / 10.0 # variance in the range of the depth map, uint16 normalized to 10
|
27 |
+
if self.mode == 'slow':
|
28 |
+
self.prefetch_dino = True
|
29 |
+
self.find_scenes()
|
30 |
+
self.dino_features = dino_features
|
31 |
+
self.rendered_views_mode = rendered_views_mode
|
32 |
+
|
33 |
+
def find_dataset_location_list(self,dataset):
|
34 |
+
data_dir = None
|
35 |
+
for d in self.dir:
|
36 |
+
datasets = os.listdir(d)
|
37 |
+
if dataset in datasets:
|
38 |
+
if data_dir is not None:
|
39 |
+
raise ValueError(f"Dataset {dataset} found in multiple locations: {self.dir}")
|
40 |
+
else:
|
41 |
+
data_dir = os.path.join(d,dataset)
|
42 |
+
if data_dir is None:
|
43 |
+
raise ValueError(f"Dataset {dataset} not found in {self.dir}")
|
44 |
+
return data_dir
|
45 |
+
|
46 |
+
def find_dataset_location(self,dataset):
|
47 |
+
if isinstance(self.dir,list):
|
48 |
+
data_dir = self.find_dataset_location_list(dataset)
|
49 |
+
else:
|
50 |
+
data_dir = os.path.join(self.dir,dataset)
|
51 |
+
if not os.path.exists(data_dir):
|
52 |
+
raise ValueError(f"Dataset {dataset} not found in {self.dir}")
|
53 |
+
return data_dir
|
54 |
+
|
55 |
+
def find_scenes(self):
|
56 |
+
all_scenes = {}
|
57 |
+
print("Loading scenes...")
|
58 |
+
for dataset in self.datasets:
|
59 |
+
dataset_dir = self.find_dataset_location(dataset)
|
60 |
+
scenes = json.load(open(os.path.join(dataset_dir, f"{self.split}_scenes.json")))
|
61 |
+
scene_ids = [dataset + "_" + f.split("/")[-2] + "_" + f.split("/")[-1] for f in scenes]
|
62 |
+
all_scenes.update(dict(zip(scene_ids, scenes)))
|
63 |
+
self.scenes = all_scenes
|
64 |
+
self.scene_ids = list(self.scenes.keys())
|
65 |
+
# shuffle the scene ids
|
66 |
+
self.rng.shuffle(self.scene_ids)
|
67 |
+
if self.size > 0:
|
68 |
+
self.scene_ids = self.scene_ids[:self.size]
|
69 |
+
self.size = len(self.scene_ids)
|
70 |
+
return scenes
|
71 |
+
|
72 |
+
def __len__(self):
|
73 |
+
return self.size
|
74 |
+
|
75 |
+
def decide_context_view(self,cam_dir):
|
76 |
+
# we pick the view furthest away from the origin as the view for conditioning
|
77 |
+
cam_dirs = [d for d in os.listdir(cam_dir) if os.path.isdir(os.path.join(cam_dir,d)) and not d.startswith("gen")] # input cam needs rgb
|
78 |
+
|
79 |
+
extrinsics = {c:torch.load(os.path.join(cam_dir,c,'cam2world.pt'),map_location='cpu',weights_only=True) for c in cam_dirs}
|
80 |
+
dist_origin = {c:torch.linalg.norm(extrinsics[c][:3,3]) for c in extrinsics}
|
81 |
+
|
82 |
+
if self.view_select_mode == 'new_zoom':
|
83 |
+
# find the view with the maximum distance to the origin
|
84 |
+
input_cam = max(dist_origin,key=dist_origin.get)
|
85 |
+
# pick another random view to predict, excluding the context view
|
86 |
+
elif self.view_select_mode == 'random':
|
87 |
+
# pick a random view
|
88 |
+
input_cam = str(self.rng.choice(list(dist_origin.keys())))
|
89 |
+
# pick another random view to predict, excluding the context view
|
90 |
+
else:
|
91 |
+
raise ValueError(f"Invalid mode: {self.view_select_mode}")
|
92 |
+
|
93 |
+
if self.rendered_views_mode == "None":
|
94 |
+
pass
|
95 |
+
elif self.rendered_views_mode == "random":
|
96 |
+
cam_dirs = [d for d in os.listdir(cam_dir) if os.path.isdir(os.path.join(cam_dir,d))]
|
97 |
+
elif self.rendered_views_mode == "always":
|
98 |
+
cam_dirs_gen = [d for d in os.listdir(cam_dir) if os.path.isdir(os.path.join(cam_dir,d)) and d.startswith("gen")]
|
99 |
+
if len(cam_dirs_gen) > 0:
|
100 |
+
cam_dirs = cam_dirs_gen
|
101 |
+
else:
|
102 |
+
raise ValueError(f"Invalid mode: {self.rendered_views_mode}")
|
103 |
+
|
104 |
+
possible_views = [v for v in cam_dirs if v != input_cam]
|
105 |
+
new_cam = str(self.rng.choice(possible_views))
|
106 |
+
return input_cam,new_cam
|
107 |
+
|
108 |
+
def transform_pointmap(self,pointmap_cam,c2w):
|
109 |
+
# pointmap: shape H x W x 3
|
110 |
+
# cw2: shape 4 x 4
|
111 |
+
# we want to transform the pointmap to the world frame
|
112 |
+
pointmap_cam_h = torch.cat([pointmap_cam,torch.ones(pointmap_cam.shape[:-1]+(1,)).to(pointmap_cam.device)],dim=-1)
|
113 |
+
pointmap_world_h = pointmap_cam_h @ c2w.T
|
114 |
+
pointmap_world = pointmap_world_h[...,:3]/pointmap_world_h[...,3:4]
|
115 |
+
return pointmap_world
|
116 |
+
|
117 |
+
def load_scene_slow(self,input_cam,new_cam,cam_dir):
|
118 |
+
|
119 |
+
data = dict(new_cams={},input_cams={})
|
120 |
+
|
121 |
+
data['new_cams']['c2ws'] = [torch.load(os.path.join(cam_dir,new_cam,'cam2world.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
|
122 |
+
data['new_cams']['depths'] = [torch.load(os.path.join(cam_dir,new_cam,'depth.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
|
123 |
+
data['new_cams']['pointmaps'] = [self.transform_pointmap(torch.load(os.path.join(cam_dir,new_cam,'pointmap.pt'),map_location='cpu',weights_only=True).to(self.dtype),data['new_cams']['c2ws'][0])]
|
124 |
+
data['new_cams']['Ks'] = [torch.load(os.path.join(cam_dir,new_cam,'intrinsics.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
|
125 |
+
data['new_cams']['valid_masks'] = [torch.load(os.path.join(cam_dir,new_cam,'mask.pt'),map_location='cpu',weights_only=True).to(torch.bool)]
|
126 |
+
|
127 |
+
# add the context views
|
128 |
+
data['input_cams']['c2ws'] = [torch.load(os.path.join(cam_dir,input_cam,'cam2world.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
|
129 |
+
data['input_cams']['depths'] = [torch.load(os.path.join(cam_dir,input_cam,'depth.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
|
130 |
+
data['input_cams']['pointmaps'] = [self.transform_pointmap(torch.load(os.path.join(cam_dir,input_cam,'pointmap.pt'),map_location='cpu',weights_only=True).to(self.dtype),data['input_cams']['c2ws'][0])]
|
131 |
+
data['input_cams']['Ks'] = [torch.load(os.path.join(cam_dir,input_cam,'intrinsics.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
|
132 |
+
data['input_cams']['valid_masks'] = [torch.load(os.path.join(cam_dir,input_cam,'mask.pt'),map_location='cpu',weights_only=True).to(torch.bool)]
|
133 |
+
data['input_cams']['imgs'] = [torch.load(os.path.join(cam_dir,input_cam,'rgb.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
|
134 |
+
data['input_cams']['dino_features'] = [torch.load(os.path.join(cam_dir,input_cam,f'dino_features_layer_{l}.pt'),map_location='cpu',weights_only=True).to(self.dtype) for l in self.dino_features]
|
135 |
+
return data
|
136 |
+
|
137 |
+
def depth_to_metric(self,depth):
|
138 |
+
# depth: shape H x W
|
139 |
+
# we want to convert the depth to a metric depth
|
140 |
+
depth_max = 10.0
|
141 |
+
depth_scaled = depth_max * (depth / 65535.0)
|
142 |
+
return depth_scaled
|
143 |
+
|
144 |
+
def load_scene_fast(self,input_cam,new_cam,cam_dir):
|
145 |
+
data = dict(new_cams={},input_cams={})
|
146 |
+
data['new_cams']['c2ws'] = [torch.load(os.path.join(cam_dir,new_cam,'cam2world.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
|
147 |
+
data['new_cams']['Ks'] = [torch.load(os.path.join(cam_dir,new_cam,'intrinsics.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
|
148 |
+
data['new_cams']['depths'] = [torch.from_numpy(np.array(Image.open(os.path.join(cam_dir,new_cam,'depth.png'))).astype(np.float32))]
|
149 |
+
data['new_cams']['valid_masks'] = [torch.from_numpy(np.array(Image.open(os.path.join(cam_dir,new_cam,'mask.png'))))]
|
150 |
+
|
151 |
+
data['input_cams']['c2ws'] = [torch.load(os.path.join(cam_dir,input_cam,'cam2world.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
|
152 |
+
data['input_cams']['Ks'] = [torch.load(os.path.join(cam_dir,input_cam,'intrinsics.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
|
153 |
+
data['input_cams']['depths'] = [torch.from_numpy(np.array(Image.open(os.path.join(cam_dir,input_cam,'depth.png'))).astype(np.float32))]
|
154 |
+
data['input_cams']['valid_masks'] = [torch.from_numpy(np.array(Image.open(os.path.join(cam_dir,input_cam,'mask.png'))))]
|
155 |
+
data['input_cams']['imgs'] = [torch.from_numpy(np.array(Image.open(os.path.join(cam_dir,input_cam,'rgb.png'))))]
|
156 |
+
if self.prefetch_dino:
|
157 |
+
data['input_cams']['dino_features'] = [torch.cat([torch.load(os.path.join(cam_dir,input_cam,f'dino_features_layer_{l}.pt'),map_location='cpu',weights_only=True).to(self.dtype) for l in self.dino_features],dim=-1)]
|
158 |
+
return data
|
159 |
+
|
160 |
+
def __getitem__(self,idx):
|
161 |
+
cam_dir = os.path.join(self.scenes[self.scene_ids[idx]],'cameras')
|
162 |
+
#data['input_cams'] = {k:[v[0].unsqueeze(0)] for k,v in data['input_cams'].items()}
|
163 |
+
input_cam,new_cam = self.decide_context_view(cam_dir)
|
164 |
+
if self.mode == 'slow':
|
165 |
+
data = self.load_scene_slow(input_cam,new_cam,cam_dir)
|
166 |
+
else:
|
167 |
+
data = self.load_scene_fast(input_cam,new_cam,cam_dir)
|
168 |
+
return data
|
engine.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
bb=breakpoint
|
2 |
+
import torch
|
3 |
+
from utils.geometry import center_pointmaps, uncenter_pointmaps
|
4 |
+
from utils.utils import scenes_to_batch, batch_to_scenes
|
5 |
+
from utils.batch_prep import prepare_fast_batch, normalize_batch, denormalize_batch
|
6 |
+
from utils.viz import save_pointmaps
|
7 |
+
from tqdm import tqdm
|
8 |
+
import wandb
|
9 |
+
from utils import misc
|
10 |
+
from torch.amp import GradScaler
|
11 |
+
from utils.eval import eval_pred
|
12 |
+
from utils.geometry import depth2pts
|
13 |
+
|
14 |
+
def batch_to_device(batch,device='cuda'):
|
15 |
+
for key in batch:
|
16 |
+
if isinstance(batch[key],torch.Tensor):
|
17 |
+
batch[key] = batch[key].to(device)
|
18 |
+
elif isinstance(batch[key],dict):
|
19 |
+
batch[key] = batch_to_device(batch[key],device)
|
20 |
+
return batch
|
21 |
+
|
22 |
+
def eval_model(model,batch,mode='loss',device='cuda',dino_model=None,args=None,augmentor=None,return_scale=False):
|
23 |
+
batch = batch_to_device(batch,device)
|
24 |
+
# check if model is distributed
|
25 |
+
if isinstance(model,torch.nn.parallel.DistributedDataParallel):
|
26 |
+
dino_layers = model.module.dino_layers
|
27 |
+
else:
|
28 |
+
dino_layers = model.dino_layers
|
29 |
+
if 'pointmaps' not in list(batch['input_cams'].keys()):
|
30 |
+
batch = prepare_fast_batch(batch,dino_model,dino_layers)
|
31 |
+
|
32 |
+
normalize_mode = args.normalize_mode if args is not None else 'median'
|
33 |
+
batch, scale_factors = normalize_batch(batch,normalize_mode)
|
34 |
+
if augmentor is not None:
|
35 |
+
batch = augmentor(batch)
|
36 |
+
|
37 |
+
batch, n_cams = scenes_to_batch(batch)
|
38 |
+
batch = center_pointmaps(batch) # centering around first camera
|
39 |
+
|
40 |
+
device = args.device if args is not None else 'cuda'
|
41 |
+
with torch.amp.autocast(device_type=device, dtype=torch.bfloat16):
|
42 |
+
pred, gt, loss_dict = model(batch,mode='viz')
|
43 |
+
|
44 |
+
if 'pointmaps' not in list(pred.keys()):
|
45 |
+
pred['pointmaps'] = depth2pts(pred['depths'].squeeze(-1),batch['new_cams']['Ks'])
|
46 |
+
elif 'depths' not in list(pred.keys()):
|
47 |
+
pred['depths'] = pred['pointmaps'][...,-1]
|
48 |
+
loss_dict = {**loss_dict,**eval_pred(pred, gt)}
|
49 |
+
if mode == 'loss':
|
50 |
+
return loss_dict
|
51 |
+
elif mode == 'viz':
|
52 |
+
pred, gt, batch = uncenter_pointmaps(pred, gt, batch)
|
53 |
+
pred, gt, batch = batch_to_scenes(pred, gt,batch, n_cams)
|
54 |
+
if return_scale:
|
55 |
+
return pred, gt, loss_dict, scale_factors[0].item()
|
56 |
+
else:
|
57 |
+
return pred, gt, loss_dict
|
58 |
+
else:
|
59 |
+
raise ValueError(f"Invalid mode: {mode}")
|
60 |
+
|
61 |
+
def update_loss_dict(loss_dict,loss_dict_new,sample_count):
|
62 |
+
for key in loss_dict_new:
|
63 |
+
if key not in loss_dict:
|
64 |
+
loss_dict[key] = loss_dict_new[key]
|
65 |
+
else:
|
66 |
+
# previously stored value in loss_dict is average from sample_count samples
|
67 |
+
# so we need to update it to include the new sample
|
68 |
+
loss_dict[key] = (loss_dict[key] * sample_count + loss_dict_new[key]) / (sample_count + 1)
|
69 |
+
return loss_dict
|
70 |
+
|
71 |
+
def train_epoch(model, train_loader, optimizer, device='cuda', max_norm=1.0,log_wandb=False,epoch=0,batch_size=None,args=None,dino_model=None,augmentor=None):
|
72 |
+
model.train()
|
73 |
+
all_losses_dict = {}
|
74 |
+
|
75 |
+
sample_idx = epoch * batch_size * len(train_loader)
|
76 |
+
scaler = GradScaler()
|
77 |
+
for i, batch in tqdm(enumerate(train_loader),total=len(train_loader)):
|
78 |
+
optimizer.zero_grad()
|
79 |
+
new_loss_dict = eval_model(model, batch, mode='loss', device=device,dino_model=dino_model,args=args,augmentor=augmentor)
|
80 |
+
loss = new_loss_dict['loss']
|
81 |
+
if loss is None:
|
82 |
+
continue
|
83 |
+
|
84 |
+
scaler.scale(loss).backward()
|
85 |
+
# Unscales the gradients of optimizer's assigned params in-place
|
86 |
+
scaler.unscale_(optimizer)
|
87 |
+
|
88 |
+
grad_norm = torch.norm(torch.stack([torch.norm(p.grad) for p in model.parameters() if p.grad is not None]))
|
89 |
+
if grad_norm.isnan():
|
90 |
+
breakpoint()
|
91 |
+
|
92 |
+
## Since the gradients of optimizer's assigned params are unscaled, clips as usual:
|
93 |
+
if max_norm > 0:
|
94 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
|
95 |
+
|
96 |
+
# optimizer's gradients are already unscaled, so scaler.step does not unscale them,
|
97 |
+
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
|
98 |
+
scaler.step(optimizer)
|
99 |
+
|
100 |
+
# Updates the scale for next iteration.
|
101 |
+
scaler.update()
|
102 |
+
|
103 |
+
new_loss_dict['grad_norm'] = grad_norm.detach().cpu().item()
|
104 |
+
|
105 |
+
misc.adjust_learning_rate(optimizer, epoch + i/len(train_loader), args)
|
106 |
+
optimizer.step()
|
107 |
+
|
108 |
+
new_loss_dict = {k: (v.detach().cpu().item() if isinstance(v, torch.Tensor) else v) for k, v in new_loss_dict.items()}
|
109 |
+
if log_wandb:
|
110 |
+
wandb_dict = {f"train_{k}":v for k,v in new_loss_dict.items()}
|
111 |
+
wandb.log(wandb_dict, step=sample_idx + (i+1)*batch_size)
|
112 |
+
# log learning rate
|
113 |
+
wandb.log({"train_lr": optimizer.param_groups[0]['lr']}, step=sample_idx + (i+1)*batch_size)
|
114 |
+
|
115 |
+
all_losses_dict = update_loss_dict(all_losses_dict, new_loss_dict,sample_count=i)
|
116 |
+
# Clear cache and delete variables to free memory
|
117 |
+
torch.cuda.empty_cache()
|
118 |
+
del loss
|
119 |
+
del new_loss_dict
|
120 |
+
del grad_norm
|
121 |
+
del batch
|
122 |
+
|
123 |
+
return all_losses_dict
|
124 |
+
|
125 |
+
def eval_epoch(model,test_loader,device='cuda',dino_model=None,args=None,augmentor=None):
|
126 |
+
model.eval()
|
127 |
+
all_losses_dict = {}
|
128 |
+
with torch.no_grad():
|
129 |
+
for i, batch in tqdm(enumerate(test_loader),total=len(test_loader)):
|
130 |
+
new_loss_dict = eval_model(model,batch,mode='loss',device=device,dino_model=dino_model,args=args,augmentor=augmentor)
|
131 |
+
if new_loss_dict is None:
|
132 |
+
continue
|
133 |
+
all_losses_dict = update_loss_dict(all_losses_dict,new_loss_dict,sample_count=i)
|
134 |
+
|
135 |
+
torch.cuda.empty_cache()
|
136 |
+
del new_loss_dict
|
137 |
+
del batch
|
138 |
+
|
139 |
+
return all_losses_dict
|
eval_wrapper/eval.py
ADDED
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
from torchvision import transforms
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
import open3d as o3d
|
9 |
+
current_dir = os.getcwd()
|
10 |
+
sys.path.append(current_dir)
|
11 |
+
|
12 |
+
from eval_wrapper.sample_poses import pointmap_to_poses
|
13 |
+
from utils.fusion import fuse_batch
|
14 |
+
from models.rayquery import *
|
15 |
+
from models.losses import *
|
16 |
+
import argparse
|
17 |
+
from utils import misc
|
18 |
+
import torch.distributed as dist
|
19 |
+
from utils.collate import collate
|
20 |
+
from engine import eval_model
|
21 |
+
from utils.viz import just_load_viz
|
22 |
+
from utils.geometry import compute_pointmap_torch
|
23 |
+
from eval_wrapper.eval_utils import npy2ply, filter_all_masks
|
24 |
+
from huggingface_hub import hf_hub_download
|
25 |
+
|
26 |
+
class EvalWrapper(torch.nn.Module):
|
27 |
+
def __init__(self,checkpoint_path,distributed=False,device="cuda",dtype=torch.float32,**kwargs):
|
28 |
+
super().__init__()
|
29 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
30 |
+
model_string = checkpoint['args'].model
|
31 |
+
|
32 |
+
self.model = eval(model_string).to(device)
|
33 |
+
if distributed:
|
34 |
+
rank, world_size, local_rank = misc.setup_distributed()
|
35 |
+
self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[local_rank],find_unused_parameters=True)
|
36 |
+
|
37 |
+
self.dtype = dtype
|
38 |
+
self.model.load_state_dict(checkpoint['model'])
|
39 |
+
self.model.eval()
|
40 |
+
|
41 |
+
def forward(self,x,dino_model=None):
|
42 |
+
pred, gt, loss, scale = eval_model(self.model,x,mode='viz',dino_model=dino_model,return_scale=True)
|
43 |
+
return pred, gt, loss, scale
|
44 |
+
|
45 |
+
class PostProcessWrapper(torch.nn.Module):
|
46 |
+
def __init__(self,pred_mask_threshold = 0.5, mode='novel_views',
|
47 |
+
debug=False,conf_dist_mode='isotonic',set_conf=None,percentile=20,
|
48 |
+
no_input_mask=False,no_pred_mask=False):
|
49 |
+
super().__init__()
|
50 |
+
self.pred_mask_threshold = pred_mask_threshold
|
51 |
+
self.mode = mode
|
52 |
+
self.debug = debug
|
53 |
+
self.conf_dist_mode = conf_dist_mode
|
54 |
+
self.set_conf = set_conf
|
55 |
+
self.percentile = percentile
|
56 |
+
self.no_input_mask = no_input_mask
|
57 |
+
self.no_pred_mask = no_pred_mask
|
58 |
+
|
59 |
+
def transform_pointmap(self,pointmap_cam,c2w):
|
60 |
+
# pointmap: shape H x W x 3
|
61 |
+
# cw2: shape 4 x 4
|
62 |
+
# we want to transform the pointmap to the world frame
|
63 |
+
pointmap_cam_h = torch.cat([pointmap_cam,torch.ones(pointmap_cam.shape[:-1]+(1,)).to(pointmap_cam.device)],dim=-1)
|
64 |
+
pointmap_world_h = pointmap_cam_h @ c2w.T
|
65 |
+
pointmap_world = pointmap_world_h[...,:3]/pointmap_world_h[...,3:4]
|
66 |
+
return pointmap_world
|
67 |
+
|
68 |
+
def reject_conf_points(self,conf_pts):
|
69 |
+
if self.set_conf is None:
|
70 |
+
raise ValueError("set_conf must be set")
|
71 |
+
|
72 |
+
conf_mask = conf_pts > self.set_conf
|
73 |
+
return conf_mask
|
74 |
+
|
75 |
+
|
76 |
+
def project_input_mask(self,pred_dict,batch):
|
77 |
+
input_mask = batch['input_cams']['original_valid_masks'][0][0] # shape H x W
|
78 |
+
input_c2w = batch['input_cams']['c2ws'][0][0]
|
79 |
+
input_w2c = torch.linalg.inv(input_c2w)
|
80 |
+
input_K = batch['input_cams']['Ks'][0][0]
|
81 |
+
H, W = input_mask.shape
|
82 |
+
pointmaps_input_cam = torch.stack([self.transform_pointmap(pmap,input_w2c@c2w) for pmap,c2w in zip(pred_dict['pointmaps'][0],batch['new_cams']['c2ws'][0])]) # bp: Assuming batch size is 1!!
|
83 |
+
img_coords = pointmaps_input_cam @ input_K.T
|
84 |
+
img_coords = (img_coords[...,:2]/img_coords[...,2:3]).int()
|
85 |
+
|
86 |
+
n_views, H, W = img_coords.shape[:3]
|
87 |
+
device = input_mask.device
|
88 |
+
if self.no_input_mask:
|
89 |
+
combined_mask = torch.ones((n_views, H, W), device=device)
|
90 |
+
else:
|
91 |
+
combined_mask = torch.zeros((n_views, H, W), device=device)
|
92 |
+
|
93 |
+
# Flatten spatial dims
|
94 |
+
xs = img_coords[..., 0].view(n_views, -1) # [V, H*W]
|
95 |
+
ys = img_coords[..., 1].view(n_views, -1) # [V, H*W]
|
96 |
+
|
97 |
+
# Create base pixel coords (i, j)
|
98 |
+
i_coords = torch.arange(H, device=device).view(-1, 1).expand(H, W).reshape(-1) # [H*W]
|
99 |
+
j_coords = torch.arange(W, device=device).view(1, -1).expand(H, W).reshape(-1) # [H*W]
|
100 |
+
mask_coords = torch.stack((i_coords, j_coords), dim=-1) # [H*W, 2], shared across views
|
101 |
+
|
102 |
+
# Mask for valid projections
|
103 |
+
valid = (xs >= 0) & (xs < W) & (ys >= 0) & (ys < H) # [V, H*W]
|
104 |
+
|
105 |
+
# Clip out-of-bounds coords for indexing (only valid will be used anyway)
|
106 |
+
xs_clipped = torch.clamp(xs, 0, W-1)
|
107 |
+
ys_clipped = torch.clamp(ys, 0, H-1)
|
108 |
+
|
109 |
+
# input_mask lookup per view
|
110 |
+
flat_input_mask = input_mask[ys_clipped, xs_clipped] # [V, H*W]
|
111 |
+
input_mask_mask = flat_input_mask & valid # apply valid range mask
|
112 |
+
|
113 |
+
# Apply mask to coords and depths
|
114 |
+
depth_points = pointmaps_input_cam[..., -1].view(n_views, -1) # [V, H*W]
|
115 |
+
input_depths = batch['input_cams']['depths'][0][0][ys_clipped, xs_clipped] # [V, H*W]
|
116 |
+
|
117 |
+
depth_mask = (depth_points > input_depths) & input_mask_mask # final mask [V, H*W]
|
118 |
+
#depth_mask = input_mask_mask # final mask [V, H*W]
|
119 |
+
|
120 |
+
# Get final (i,j) coords to write
|
121 |
+
final_i = mask_coords[:, 0].unsqueeze(0).expand(n_views, -1)[depth_mask] # [N_mask]
|
122 |
+
final_j = mask_coords[:, 1].unsqueeze(0).expand(n_views, -1)[depth_mask] # [N_mask]
|
123 |
+
final_view_idx = torch.arange(n_views, device=device).view(-1, 1).expand(-1, H*W)[depth_mask] # [N_mask]
|
124 |
+
|
125 |
+
# Scatter final mask
|
126 |
+
combined_mask[final_view_idx, final_i, final_j] = 1
|
127 |
+
return combined_mask.unsqueeze(0).bool()
|
128 |
+
|
129 |
+
def forward(self,pred_dict,batch):
|
130 |
+
if self.mode == 'novel_views':
|
131 |
+
project_masks = self.project_input_mask(pred_dict,batch)
|
132 |
+
pred_mask_raw = torch.sigmoid(pred_dict['classifier'])
|
133 |
+
if self.no_pred_mask:
|
134 |
+
pred_masks = torch.ones_like(project_masks).bool()
|
135 |
+
else:
|
136 |
+
pred_masks = (pred_mask_raw > self.pred_mask_threshold).bool()
|
137 |
+
|
138 |
+
conf_masks = self.reject_conf_points(pred_dict['conf_pointmaps'])
|
139 |
+
combined_mask = project_masks & pred_masks & conf_masks
|
140 |
+
batch['new_cams']['valid_masks'] = combined_mask
|
141 |
+
|
142 |
+
elif self.mode == 'input_view':
|
143 |
+
conf_masks = self.reject_conf_points(pred_dict['conf_pointmaps'])
|
144 |
+
if self.no_pred_mask:
|
145 |
+
pred_masks = torch.ones_like(conf_masks).bool()
|
146 |
+
else:
|
147 |
+
pred_mask_raw = torch.sigmoid(pred_dict['classifier'])
|
148 |
+
pred_masks = (pred_mask_raw > self.pred_mask_threshold).bool()
|
149 |
+
combined_mask = conf_masks & batch['new_cams']['valid_masks'] & pred_masks
|
150 |
+
batch['new_cams']['valid_masks'] = combined_mask # this is for visualization
|
151 |
+
|
152 |
+
return pred_dict, batch
|
153 |
+
|
154 |
+
class GenericLoaderSmall(torch.utils.data.Dataset):
|
155 |
+
def __init__(self,data_dir,mode="single_scene",dtype=torch.float32,n_pred_views=3,pred_input_only=False,min_depth=0.1,
|
156 |
+
pointmap_for_bb=None,run_octmae=False,false_positive=None,false_negative=None):
|
157 |
+
self.data_dir = data_dir
|
158 |
+
self.mode = mode
|
159 |
+
self.dtype = dtype
|
160 |
+
self.rng = np.random.RandomState(seed=42)
|
161 |
+
self.n_pred_views = n_pred_views
|
162 |
+
self.min_depth = self.depth_metric_to_uint16(min_depth)
|
163 |
+
if self.mode == "single_scene":
|
164 |
+
self.inputs = [data_dir]
|
165 |
+
self.pred_input_only = pred_input_only
|
166 |
+
if self.pred_input_only:
|
167 |
+
self.n_pred_views = 1
|
168 |
+
self.desired_resolution = (480,640)
|
169 |
+
self.resize_transform_rgb = transforms.Resize(self.desired_resolution)
|
170 |
+
self.resize_transform_depth = transforms.Resize(self.desired_resolution,interpolation=transforms.InterpolationMode.NEAREST)
|
171 |
+
self.pointmap_for_bb = pointmap_for_bb
|
172 |
+
self.run_octmae = run_octmae
|
173 |
+
self.false_positive = false_positive
|
174 |
+
self.false_negative = false_negative
|
175 |
+
|
176 |
+
def transform_pointmap(self,pointmap_cam,c2w):
|
177 |
+
# pointmap: shape H x W x 3
|
178 |
+
# cw2: shape 4 x 4
|
179 |
+
# we want to transform the pointmap to the world frame
|
180 |
+
pointmap_cam_h = torch.cat([pointmap_cam,torch.ones(pointmap_cam.shape[:-1]+(1,)).to(pointmap_cam.device)],dim=-1)
|
181 |
+
pointmap_world_h = pointmap_cam_h @ c2w.T
|
182 |
+
pointmap_world = pointmap_world_h[...,:3]/pointmap_world_h[...,3:4]
|
183 |
+
return pointmap_world
|
184 |
+
|
185 |
+
def __len__(self):
|
186 |
+
return len(self.inputs)
|
187 |
+
|
188 |
+
def look_at(self,cam_pos, center=(0,0,0), up=(0,0,1)):
|
189 |
+
z = center - cam_pos
|
190 |
+
z /= np.linalg.norm(z, axis=-1, keepdims=True)
|
191 |
+
y = -np.float32(up)
|
192 |
+
y = y - np.sum(y * z, axis=-1, keepdims=True) * z
|
193 |
+
y /= np.linalg.norm(y, axis=-1, keepdims=True)
|
194 |
+
x = np.cross(y, z, axis=-1)
|
195 |
+
|
196 |
+
cam2w = np.r_[np.c_[x,y,z,cam_pos],[[0,0,0,1]]]
|
197 |
+
return cam2w.astype(np.float32)
|
198 |
+
|
199 |
+
def find_new_views(self,n_views,geometric_median = (0,0,0),r_min=0.4,r_max=0.9):
|
200 |
+
rad = self.rng.uniform(r_min,r_max, size=n_views)
|
201 |
+
azi = self.rng.uniform(0, 2*np.pi, size=n_views)
|
202 |
+
ele = self.rng.uniform(-np.pi, np.pi, size=n_views)
|
203 |
+
cam_centers = np.c_[np.cos(azi), np.sin(azi)]
|
204 |
+
cam_centers = rad[:,None] * np.c_[np.cos(ele)[:,None]*cam_centers, np.sin(ele)] + geometric_median
|
205 |
+
|
206 |
+
c2ws = [self.look_at(cam_pos=cam_center,center=geometric_median) for cam_center in cam_centers]
|
207 |
+
return c2ws
|
208 |
+
|
209 |
+
def depth_uint16_to_metric(self,depth):
|
210 |
+
return depth / torch.iinfo(torch.uint16).max * 10.0 # threshold is in m, convert to uint16 value
|
211 |
+
|
212 |
+
def depth_metric_to_uint16(self,depth):
|
213 |
+
return depth * torch.iinfo(torch.uint16).max / 10.0 # threshold is in m, convert to uint16 value
|
214 |
+
|
215 |
+
def resize(self,depth,img,mask,K):
|
216 |
+
s_x = self.desired_resolution[1] / img.shape[1]
|
217 |
+
s_y = self.desired_resolution[0] / img.shape[0]
|
218 |
+
depth = self.resize_transform_depth(depth.unsqueeze(0)).squeeze(0)
|
219 |
+
img = self.resize_transform_rgb(img.permute(-1,0,1)).permute(1,2,0)
|
220 |
+
mask = self.resize_transform_depth(mask.unsqueeze(0)).squeeze(0)
|
221 |
+
K[0] *= s_x
|
222 |
+
K[1] *= s_y
|
223 |
+
return depth, img, mask, K
|
224 |
+
|
225 |
+
def add_false_positives_and_negatives(self,valid_mask,false_positive,false_negative):
|
226 |
+
# add false positives to the valid mask
|
227 |
+
# add false negatives to the valid mask
|
228 |
+
# return the new valid mask
|
229 |
+
n_total_pixels = valid_mask.sum()
|
230 |
+
n_pixels_left = n_total_pixels * (1-false_positive)
|
231 |
+
|
232 |
+
mask_pixels_coords = torch.where(valid_mask)
|
233 |
+
left_pixels_coords = torch.where(~valid_mask)
|
234 |
+
|
235 |
+
# false positives
|
236 |
+
n_false_positives = min(int(n_pixels_left * false_positive),n_pixels_left)
|
237 |
+
# randomly sample n_false_positives from mask_pixels_coords
|
238 |
+
false_positives = torch.randperm(len(left_pixels_coords[0]))[:n_false_positives]
|
239 |
+
valid_mask[left_pixels_coords[0][false_positives],left_pixels_coords[1][false_positives]] = 1
|
240 |
+
|
241 |
+
# false negatives
|
242 |
+
n_false_negatives = min(int(n_total_pixels * false_negative),n_total_pixels)
|
243 |
+
# randomly sample n_false_negatives from left_pixels_coords
|
244 |
+
false_negatives = torch.randperm(len(mask_pixels_coords[0]))[:n_false_negatives]
|
245 |
+
valid_mask[mask_pixels_coords[0][false_negatives],mask_pixels_coords[1][false_negatives]] = 0
|
246 |
+
|
247 |
+
return valid_mask
|
248 |
+
|
249 |
+
def __getitem__(self,idx):
|
250 |
+
scene_dir = self.inputs[idx]
|
251 |
+
|
252 |
+
data = dict(new_cams={},input_cams={})
|
253 |
+
|
254 |
+
c2w_path = os.path.join(scene_dir,'cam2world.pt')
|
255 |
+
if os.path.exists(c2w_path):
|
256 |
+
data['input_cams']['c2ws_original'] = [torch.load(c2w_path,map_location='cpu',weights_only=True).to(self.dtype)]
|
257 |
+
else:
|
258 |
+
data['input_cams']['c2ws_original'] = [torch.eye(4).to(self.dtype)]
|
259 |
+
|
260 |
+
data['input_cams']['c2ws'] = [torch.eye(4).to(self.dtype)]
|
261 |
+
data['input_cams']['Ks'] = [torch.load(os.path.join(scene_dir,'intrinsics.pt'),map_location='cpu',weights_only=True).to(self.dtype)]
|
262 |
+
data['input_cams']['depths'] = [torch.from_numpy(np.array(Image.open(os.path.join(scene_dir,'depth.png'))).astype(np.float32))]
|
263 |
+
data['input_cams']['valid_masks'] = [torch.from_numpy(np.array(Image.open(os.path.join(scene_dir,'mask.png')))).bool()]
|
264 |
+
data['input_cams']['imgs'] = [torch.from_numpy(np.array(Image.open(os.path.join(scene_dir,'rgb.png'))))]
|
265 |
+
|
266 |
+
if self.false_positive is not None or self.false_negative is not None:
|
267 |
+
data['input_cams']['valid_masks'][0] = self.add_false_positives_and_negatives(data['input_cams']['valid_masks'][0],self.false_positive,self.false_negative)
|
268 |
+
|
269 |
+
if data['input_cams']['depths'][0].shape != self.desired_resolution:
|
270 |
+
data['input_cams']['depths'][0], data['input_cams']['imgs'][0], data['input_cams']['valid_masks'][0], data['input_cams']['Ks'][0] = \
|
271 |
+
self.resize(data['input_cams']['depths'][0], data['input_cams']['imgs'][0], data['input_cams']['valid_masks'][0], data['input_cams']['Ks'][0])
|
272 |
+
|
273 |
+
data['input_cams']['original_valid_masks'] = [data['input_cams']['valid_masks'][0].clone()]
|
274 |
+
data['input_cams']['valid_masks'][0] = data['input_cams']['valid_masks'][0] & \
|
275 |
+
(data['input_cams']['depths'][0] > self.min_depth)
|
276 |
+
|
277 |
+
if self.pred_input_only:
|
278 |
+
c2ws = [data['input_cams']['c2ws'][0].cpu().numpy()]
|
279 |
+
else:
|
280 |
+
input_mask = data['input_cams']['valid_masks'][0]
|
281 |
+
if self.pointmap_for_bb is not None:
|
282 |
+
pointmap_input = self.pointmap_for_bb
|
283 |
+
else:
|
284 |
+
pointmap_input = compute_pointmap_torch(self.depth_uint16_to_metric(data['input_cams']['depths'][0]),data['input_cams']['c2ws'][0],data['input_cams']['Ks'][0],device='cpu')[input_mask]
|
285 |
+
c2ws = pointmap_to_poses(pointmap_input, self.n_pred_views, inner_radius=1.1, outer_radius=2.5, device='cpu',run_octmae=self.run_octmae)
|
286 |
+
self.n_pred_views = len(c2ws)
|
287 |
+
|
288 |
+
data['new_cams'] = {}
|
289 |
+
data['new_cams']['c2ws'] = [torch.from_numpy(c2w).to(self.dtype) for c2w in c2ws]
|
290 |
+
data['new_cams']['depths'] = [torch.zeros_like(data['input_cams']['depths'][0]) for _ in range(self.n_pred_views)]
|
291 |
+
data['new_cams']['Ks'] = [data['input_cams']['Ks'][0] for _ in range(self.n_pred_views)]
|
292 |
+
if self.pred_input_only:
|
293 |
+
data['new_cams']['valid_masks'] = data['input_cams']['original_valid_masks']
|
294 |
+
else:
|
295 |
+
data['new_cams']['valid_masks'] = [torch.ones_like(data['input_cams']['valid_masks'][0]) for _ in range(self.n_pred_views)]
|
296 |
+
|
297 |
+
return data
|
298 |
+
|
299 |
+
def dict_to_float(d):
|
300 |
+
return {k: v.float() for k, v in d.items()}
|
301 |
+
|
302 |
+
def merge_dicts(d1,d2):
|
303 |
+
# stack the tensors along dimension 1
|
304 |
+
for k,v in d1.items():
|
305 |
+
d1[k] = torch.cat([d1[k],d2[k]],dim=1)
|
306 |
+
return d1
|
307 |
+
|
308 |
+
def compute_all_points(pred_dict,batch):
|
309 |
+
n_views = pred_dict['depths'].shape[1]
|
310 |
+
all_points = None
|
311 |
+
for i in range(n_views):
|
312 |
+
mask = batch['new_cams']['valid_masks'][0,i]
|
313 |
+
pointmap = compute_pointmap_torch(pred_dict['depths'][0,i],batch['new_cams']['c2ws'][0,i],batch['new_cams']['Ks'][0,i])
|
314 |
+
masked_points = pointmap[mask]
|
315 |
+
if all_points is None:
|
316 |
+
all_points = masked_points
|
317 |
+
else:
|
318 |
+
all_points = torch.cat([all_points,masked_points],dim=0)
|
319 |
+
return all_points
|
320 |
+
|
321 |
+
def eval_scene(model, data_dir,visualize=False,rr_addr=None,run_octmae=False,set_conf=5,
|
322 |
+
no_input_mask=False,no_pred_mask=False,no_filter_input_view=False,false_positive=None,false_negative=None,n_pred_views=5,
|
323 |
+
do_filter_all_masks=False, dino_model=None,tsdf=False):
|
324 |
+
|
325 |
+
if dino_model is None:
|
326 |
+
# Loading DINOv2 model
|
327 |
+
dino_model = torch.hub.load('facebookresearch/dinov2', "dinov2_vitl14_reg")
|
328 |
+
dino_model.eval()
|
329 |
+
dino_model.to("cuda")
|
330 |
+
|
331 |
+
dataloader_input_view = GenericLoaderSmall(data_dir,n_pred_views=1,pred_input_only=True,false_positive=false_positive,false_negative=false_negative)
|
332 |
+
input_view_loader = DataLoader(dataloader_input_view, batch_size=1, shuffle=True, collate_fn=collate)
|
333 |
+
input_view_batch = next(iter(input_view_loader))
|
334 |
+
|
335 |
+
postprocessor_input_view = PostProcessWrapper(mode='input_view',set_conf=set_conf,
|
336 |
+
no_input_mask=no_input_mask,no_pred_mask=no_pred_mask)
|
337 |
+
postprocessor_pred_views = PostProcessWrapper(mode='novel_views',debug=False,set_conf=set_conf,
|
338 |
+
no_input_mask=no_input_mask,no_pred_mask=no_pred_mask)
|
339 |
+
fused_meshes = None
|
340 |
+
with torch.no_grad():
|
341 |
+
pred_input_view, gt_input_view, _, scale_factor = model(input_view_batch,dino_model)
|
342 |
+
if no_filter_input_view:
|
343 |
+
pred_input_view['pointmaps'] = input_view_batch['input_cams']['pointmaps']
|
344 |
+
pred_input_view['depths'] = input_view_batch['input_cams']['depths']
|
345 |
+
else:
|
346 |
+
pred_input_view, input_view_batch = postprocessor_input_view(pred_input_view,input_view_batch)
|
347 |
+
|
348 |
+
input_points = pred_input_view['pointmaps'][0][0][input_view_batch['new_cams']['valid_masks'][0][0]] * (1.0/scale_factor)
|
349 |
+
if input_points.shape[0] == 0:
|
350 |
+
input_points = None
|
351 |
+
|
352 |
+
dataloader_pred_views = GenericLoaderSmall(data_dir,n_pred_views=n_pred_views,pred_input_only=False,
|
353 |
+
pointmap_for_bb=input_points,run_octmae=run_octmae)
|
354 |
+
pred_views_loader = DataLoader(dataloader_pred_views, batch_size=1, shuffle=True, collate_fn=collate)
|
355 |
+
pred_views_batch = next(iter(pred_views_loader))
|
356 |
+
|
357 |
+
# this is for the mask ablation
|
358 |
+
if (false_positive is not None or false_negative is not None) and input_points is not None:
|
359 |
+
pred_views_batch['input_cams']['valid_masks'] = input_view_batch['input_cams']['valid_masks']
|
360 |
+
|
361 |
+
pred_new_views, gt_new_views, _, scale_factor = model(pred_views_batch,dino_model)
|
362 |
+
pred_new_views, pred_views_batch = postprocessor_pred_views(pred_new_views,pred_views_batch)
|
363 |
+
|
364 |
+
pred = merge_dicts(dict_to_float(pred_input_view),dict_to_float(pred_new_views))
|
365 |
+
gt = merge_dicts(dict_to_float(gt_input_view),dict_to_float(gt_new_views))
|
366 |
+
|
367 |
+
batch = copy.deepcopy(input_view_batch)
|
368 |
+
batch['new_cams'] = merge_dicts(input_view_batch['new_cams'],pred_views_batch['new_cams'])
|
369 |
+
gt['pointmaps'] = None # make sure it's not used in viz
|
370 |
+
|
371 |
+
if do_filter_all_masks:
|
372 |
+
batch = filter_all_masks(pred,input_view_batch,max_outlier_views=1)
|
373 |
+
|
374 |
+
# scale factor is the scale we applied to the input view for inference
|
375 |
+
all_points = compute_all_points(pred,batch)
|
376 |
+
all_points = all_points*(1.0/scale_factor)
|
377 |
+
|
378 |
+
# transform all_points to the original coordinate system
|
379 |
+
all_points_h = torch.cat([all_points,torch.ones(all_points.shape[:-1]+(1,)).to(all_points.device)],dim=-1)
|
380 |
+
all_points_original = all_points_h @ batch['input_cams']['c2ws_original'][0][0].T
|
381 |
+
all_points = all_points_original[...,:3]
|
382 |
+
|
383 |
+
# uncomment this to visualize a simple TSDF
|
384 |
+
if tsdf:
|
385 |
+
fused_meshes = fuse_batch(pred,gt,batch,voxel_size=0.002)
|
386 |
+
else:
|
387 |
+
fused_meshes = None
|
388 |
+
|
389 |
+
if visualize:
|
390 |
+
just_load_viz(pred, gt, batch, addr=rr_addr,fused_meshes=fused_meshes)
|
391 |
+
return all_points
|
392 |
+
|
393 |
+
|
394 |
+
def main():
|
395 |
+
parser = argparse.ArgumentParser()
|
396 |
+
parser.add_argument("data_dir", type=str)
|
397 |
+
parser.add_argument("--rr_addr", type=str, default="0.0.0.0:"+os.getenv("RERUN_RECORDING","9876"))
|
398 |
+
parser.add_argument("--visualize", action="store_true", default=False)
|
399 |
+
parser.add_argument("--run_octmae", action="store_true", default=False)
|
400 |
+
parser.add_argument("--set_conf", type=float, default=5)
|
401 |
+
parser.add_argument("--n_pred_views", type=int, default=5)
|
402 |
+
parser.add_argument("--filter_all_masks", action="store_true", default=False)
|
403 |
+
parser.add_argument("--tsdf", action="store_true", default=False)
|
404 |
+
# ablation settings
|
405 |
+
parser.add_argument("--no_input_mask", action="store_true", default=False)
|
406 |
+
parser.add_argument("--no_pred_mask", action="store_true", default=False)
|
407 |
+
parser.add_argument("--no_filter_input_view", action="store_true", default=False)
|
408 |
+
parser.add_argument("--false_positive", type=float, default=None)
|
409 |
+
parser.add_argument("--false_negative", type=float, default=None)
|
410 |
+
args = parser.parse_args()
|
411 |
+
|
412 |
+
print("Loading checkpoint from Huggingface")
|
413 |
+
rayst3r_checkpoint = hf_hub_download("bartduis/rayst3r", "rayst3r.pth")
|
414 |
+
|
415 |
+
model = EvalWrapper(rayst3r_checkpoint,distributed=False)
|
416 |
+
all_points = eval_scene(model, args.data_dir,visualize=args.visualize,rr_addr=args.rr_addr,run_octmae=args.run_octmae,set_conf=args.set_conf,
|
417 |
+
no_input_mask=args.no_input_mask,no_pred_mask=args.no_pred_mask,no_filter_input_view=args.no_filter_input_view,false_positive=args.false_positive,
|
418 |
+
false_negative=args.false_negative,n_pred_views=args.n_pred_views,
|
419 |
+
do_filter_all_masks=args.filter_all_masks,tsdf=args.tsdf).cpu().numpy()
|
420 |
+
all_points_save = os.path.join(args.data_dir,"inference_points.ply")
|
421 |
+
o3d_pc = npy2ply(all_points,colors=None,normals=None)
|
422 |
+
o3d.io.write_point_cloud(all_points_save, o3d_pc)
|
423 |
+
|
424 |
+
if __name__ == "__main__":
|
425 |
+
main()
|
eval_wrapper/eval_utils.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
from scipy.stats import norm, lognorm
|
4 |
+
import torch
|
5 |
+
import open3d as o3d
|
6 |
+
|
7 |
+
def colorize_points_with_turbo_all_dims(points, method='norm',cmap='turbo'):
|
8 |
+
"""
|
9 |
+
Assigns colors to 3D points using the 'turbo' colormap based on a scalar computed from all 3 dimensions.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
points (np.ndarray): (N, 3) array of 3D points.
|
13 |
+
method (str): Method for reducing 3D point to scalar. Options: 'norm', 'pca'.
|
14 |
+
|
15 |
+
Returns:
|
16 |
+
np.ndarray: (N, 3) RGB colors in [0, 1].
|
17 |
+
"""
|
18 |
+
assert points.shape[1] == 3, "Input must be of shape (N, 3)"
|
19 |
+
|
20 |
+
if method == 'norm':
|
21 |
+
scalar = np.linalg.norm(points, axis=1)
|
22 |
+
elif method == 'pca':
|
23 |
+
# Project onto first principal component
|
24 |
+
mean = points.mean(axis=0)
|
25 |
+
centered = points - mean
|
26 |
+
u, s, vh = np.linalg.svd(centered, full_matrices=False)
|
27 |
+
scalar = centered @ vh[0] # Project onto first principal axis
|
28 |
+
else:
|
29 |
+
raise ValueError(f"Unknown method '{method}'")
|
30 |
+
|
31 |
+
# Normalize scalar to [0, 1]
|
32 |
+
scalar_min, scalar_max = scalar.min(), scalar.max()
|
33 |
+
normalized = (scalar - scalar_min) / (scalar_max - scalar_min + 1e-8)
|
34 |
+
|
35 |
+
# Apply turbo colormap
|
36 |
+
cmap = plt.colormaps.get_cmap(cmap)
|
37 |
+
colors = cmap(normalized)[:, :3] # Drop alpha
|
38 |
+
|
39 |
+
return colors
|
40 |
+
|
41 |
+
def npy2ply(points,colors=None,normals=None):
|
42 |
+
cloud = o3d.geometry.PointCloud()
|
43 |
+
cloud.points = o3d.utility.Vector3dVector(points.astype(np.float64))
|
44 |
+
|
45 |
+
# compute the normals
|
46 |
+
if colors is not None:
|
47 |
+
if colors.max()>1:
|
48 |
+
colors = colors/255.0
|
49 |
+
cloud.colors = o3d.utility.Vector3dVector(colors.astype(np.float64))
|
50 |
+
else:
|
51 |
+
colors = colorize_points_with_turbo_all_dims(points)
|
52 |
+
cloud.colors = o3d.utility.Vector3dVector(colors.astype(np.float64))
|
53 |
+
if normals is not None:
|
54 |
+
cloud.normals = o3d.utility.Vector3dVector(normals.astype(np.float64))
|
55 |
+
return cloud
|
56 |
+
|
57 |
+
def transform_pointmap(pointmap_cam,c2w):
|
58 |
+
# pointmap: shape H x W x 3
|
59 |
+
# cw2: shape 4 x 4
|
60 |
+
# we want to transform the pointmap to the world frame
|
61 |
+
pointmap_cam_h = torch.cat([pointmap_cam,torch.ones(pointmap_cam.shape[:-1]+(1,)).to(pointmap_cam.device)],dim=-1)
|
62 |
+
pointmap_world_h = pointmap_cam_h @ c2w.T
|
63 |
+
pointmap_world = pointmap_world_h[...,:3]/pointmap_world_h[...,3:4]
|
64 |
+
return pointmap_world
|
65 |
+
|
66 |
+
def filter_all_masks(pred_dict, batch, max_outlier_views=1):
|
67 |
+
pred_masks = (torch.sigmoid(pred_dict['classifier'][0]).float() < 0.5).bool() # [V, H, W]
|
68 |
+
n_views, H, W = pred_masks.shape
|
69 |
+
device = pred_masks.device
|
70 |
+
|
71 |
+
K = batch['input_cams']['Ks'][0][0] # [3, 3]
|
72 |
+
c2ws = batch['new_cams']['c2ws'][0] # [V, 4, 4]
|
73 |
+
w2cs = torch.linalg.inv(c2ws) # [V, 4, 4]
|
74 |
+
|
75 |
+
pointmaps = pred_dict['pointmaps'][0] # [V, H, W, 3]
|
76 |
+
pointmaps_h = torch.cat([pointmaps, torch.ones_like(pointmaps[..., :1])], dim=-1) # [V, H, W, 4]
|
77 |
+
|
78 |
+
visibility_count = torch.zeros((n_views, H, W), dtype=torch.int32, device=device)
|
79 |
+
|
80 |
+
for j in range(n_views):
|
81 |
+
# Project pointmap j to all other views i ≠ j
|
82 |
+
pmap_h = pointmaps_h[j] # [H, W, 4], world-space points from view j
|
83 |
+
pmap_h = pmap_h.view(1, H, W, 4).expand(n_views, -1, -1, -1) # [V, H, W, 4]
|
84 |
+
|
85 |
+
# Compute T_{i←j} = w2cs[i] @ c2ws[j]
|
86 |
+
T = w2cs @ c2ws[j] # [V, 4, 4]
|
87 |
+
T = T.view(n_views, 1, 1, 4, 4) # [V, 1, 1, 4, 4]
|
88 |
+
|
89 |
+
# Transform to i-th camera frame
|
90 |
+
pts_cam = torch.matmul(T, pmap_h.unsqueeze(-1)).squeeze(-1)[..., :3] # [V, H, W, 3]
|
91 |
+
|
92 |
+
# Project to image
|
93 |
+
img_coords = torch.matmul(pts_cam, K.T) # [V, H, W, 3]
|
94 |
+
img_coords = img_coords[..., :2] / img_coords[..., 2:3].clamp(min=1e-6)
|
95 |
+
img_coords = img_coords.round().long() # [V, H, W, 2]
|
96 |
+
|
97 |
+
x = img_coords[..., 0].clamp(0, W - 1)
|
98 |
+
y = img_coords[..., 1].clamp(0, H - 1)
|
99 |
+
valid = (img_coords[..., 0] >= 0) & (img_coords[..., 0] < W) & \
|
100 |
+
(img_coords[..., 1] >= 0) & (img_coords[..., 1] < H)
|
101 |
+
|
102 |
+
# Get depth of the reprojected point from j into i
|
103 |
+
reprojected_depth = pts_cam[..., 2] # [V, H, W]
|
104 |
+
|
105 |
+
# Get depth of each view's original pointmap
|
106 |
+
target_depth = pointmaps[:, :, :, 2] # [V, H, W]
|
107 |
+
|
108 |
+
# Lookup the depth value in view i at the projected location (x, y)
|
109 |
+
depth_at_pixel = target_depth[torch.arange(n_views).view(-1, 1, 1), y, x] # [V, H, W]
|
110 |
+
|
111 |
+
# Check that the point is in front (closest along ray)
|
112 |
+
is_closest = reprojected_depth < depth_at_pixel # [V, H, W]
|
113 |
+
|
114 |
+
# Lookup mask values at projected location
|
115 |
+
projected_mask = pred_masks[torch.arange(n_views).view(-1, 1, 1), y, x] & valid # [V, H, W]
|
116 |
+
|
117 |
+
# Only consider as visible if it’s within mask and closest point
|
118 |
+
visible = projected_mask & is_closest # [V, H, W]
|
119 |
+
|
120 |
+
# Count how many views see each pixel from j
|
121 |
+
visibility_count[j] = visible.sum(dim=0)
|
122 |
+
|
123 |
+
visibility_mask = (visibility_count <= max_outlier_views).bool()
|
124 |
+
batch['new_cams']['valid_masks'] = visibility_mask & batch['new_cams']['valid_masks']
|
125 |
+
return batch
|
eval_wrapper/sample_poses.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import open3d as o3d
|
4 |
+
|
5 |
+
|
6 |
+
def look_at(cam_pos, target=(0,0,0), up=(0,0,1)):
|
7 |
+
# Forward vector
|
8 |
+
forward = target - cam_pos
|
9 |
+
forward /= np.linalg.norm(forward)
|
10 |
+
|
11 |
+
# Default up vector
|
12 |
+
right = np.cross(up, forward)
|
13 |
+
if np.linalg.norm(right) < 1e-6:
|
14 |
+
up = np.array([1, 0, 0])
|
15 |
+
right = np.cross(up, forward)
|
16 |
+
|
17 |
+
right /= np.linalg.norm(right)
|
18 |
+
up = np.cross(forward, right)
|
19 |
+
|
20 |
+
# Build rotation and translation matrices
|
21 |
+
rotation = np.eye(4)
|
22 |
+
rotation[:3, :3] = np.vstack([right, up, -forward]).T
|
23 |
+
|
24 |
+
|
25 |
+
translation = np.eye(4)
|
26 |
+
translation[:3, 3] = cam_pos
|
27 |
+
|
28 |
+
cam_to_world = translation @ rotation
|
29 |
+
cam_to_world[:3,2] = -cam_to_world[:3,2]
|
30 |
+
cam_to_world[:3,1] = -cam_to_world[:3,1]
|
31 |
+
# rotate 90 degrees around z axis
|
32 |
+
return cam_to_world
|
33 |
+
|
34 |
+
|
35 |
+
def sample_camera_poses(target: np.ndarray, inner_radius: float, outer_radius: float, n: int,seed: int = 42,mode: str = 'grid') -> np.ndarray:
|
36 |
+
"""
|
37 |
+
Samples `n` camera poses uniformly on a sphere of given `radius` around `target`.
|
38 |
+
The cameras are positioned randomly and oriented to look at `target`.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
target (np.ndarray): 3D point (x, y, z) that cameras should look at.
|
42 |
+
inner_radius (float): Radius of the sphere.
|
43 |
+
outer_radius (float): Radius of the sphere.
|
44 |
+
n (int): Number of camera poses to sample.
|
45 |
+
|
46 |
+
Returns:
|
47 |
+
torch.Tensor: (n, 4, 4) array of transformation matrices (camera-to-world).
|
48 |
+
"""
|
49 |
+
cameras = []
|
50 |
+
np.random.seed(seed)
|
51 |
+
|
52 |
+
u_1 = np.linspace(0,1,n,endpoint=False)
|
53 |
+
u_2 = np.linspace(0,0.7,n)
|
54 |
+
u_1, u_2 = np.meshgrid(u_1, u_2)
|
55 |
+
u_1 = u_1.flatten()
|
56 |
+
u_2 = u_2.flatten()
|
57 |
+
theta = np.arccos(1-2*u_2)
|
58 |
+
phi = 2*np.pi*u_1
|
59 |
+
n_poses = len(phi)
|
60 |
+
|
61 |
+
radii = np.random.uniform(inner_radius, outer_radius, n_poses)
|
62 |
+
cameras = []
|
63 |
+
|
64 |
+
r_z = np.array([[0,-1,0],[1,0,0],[0,0,1]])
|
65 |
+
|
66 |
+
for i in range(n_poses):
|
67 |
+
# Camera position on the sphere
|
68 |
+
x = target[0] + radii[i] * np.sin(theta[i]) * np.cos(phi[i])
|
69 |
+
y = target[1] + radii[i] * np.sin(theta[i]) * np.sin(phi[i])
|
70 |
+
z = target[2] + radii[i] * np.cos(theta[i])
|
71 |
+
cam_pos = np.array([x, y, z])
|
72 |
+
cam2world = look_at(cam_pos, target)
|
73 |
+
if theta[i] == 0:
|
74 |
+
cam2world[:3,:3] = cam2world[:3,:3] @ r_z # rotate 90 degrees around z axis for the camera opposite to the input
|
75 |
+
cameras.append(cam2world)
|
76 |
+
cameras = np.unique(cameras, axis=0)
|
77 |
+
return np.stack(cameras)
|
78 |
+
|
79 |
+
|
80 |
+
def pointmap_to_poses(pointmaps: torch.Tensor, n_poses: int, inner_radius: float = 1.1, outer_radius: float = 2.5, device: str = 'cuda',
|
81 |
+
bb_mode: str='bb',run_octmae: bool = False) -> np.ndarray:
|
82 |
+
"""
|
83 |
+
Samples `n_poses` camera poses uniformly on a sphere of given `radius` around `target`.
|
84 |
+
The cameras are positioned randomly and oriented to look at `target`.
|
85 |
+
"""
|
86 |
+
|
87 |
+
bb_min_corner = pointmaps.min(dim=0)[0].cpu().numpy()
|
88 |
+
bb_max_corner = pointmaps.max(dim=0)[0].cpu().numpy()
|
89 |
+
center = (bb_min_corner + bb_max_corner) / 2 #inner_radius = inner_radius * np.linalg.norm(bb_max_corner - bb_min_corner) / 2 # minimum radius is scalar multiple of bounding box radius
|
90 |
+
bb_radius = np.linalg.norm(bb_max_corner - bb_min_corner) / 2
|
91 |
+
cam2center_dist = np.linalg.norm(center)
|
92 |
+
|
93 |
+
if run_octmae:
|
94 |
+
radius = max(1.2*cam2center_dist,2.5*bb_radius)
|
95 |
+
else:
|
96 |
+
radius = max(0.7*cam2center_dist,1.3*bb_radius)
|
97 |
+
inner_radius = radius
|
98 |
+
outer_radius = radius
|
99 |
+
camera_poses = sample_camera_poses(center, inner_radius, outer_radius, n_poses)
|
100 |
+
return camera_poses
|
example_scene/cam2world.pt
ADDED
Binary file (1.25 kB). View file
|
|
example_scene/intrinsics.pt
ADDED
Binary file (1.2 kB). View file
|
|
extensions/curope/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
3 |
+
|
4 |
+
from .curope2d import cuRoPE2D
|
extensions/curope/curope.cpp
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
3 |
+
Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
4 |
+
*/
|
5 |
+
|
6 |
+
#include <torch/extension.h>
|
7 |
+
|
8 |
+
// forward declaration
|
9 |
+
void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd );
|
10 |
+
|
11 |
+
void rope_2d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd )
|
12 |
+
{
|
13 |
+
const int B = tokens.size(0);
|
14 |
+
const int N = tokens.size(1);
|
15 |
+
const int H = tokens.size(2);
|
16 |
+
const int D = tokens.size(3) / 4;
|
17 |
+
|
18 |
+
auto tok = tokens.accessor<float, 4>();
|
19 |
+
auto pos = positions.accessor<int64_t, 3>();
|
20 |
+
|
21 |
+
for (int b = 0; b < B; b++) {
|
22 |
+
for (int x = 0; x < 2; x++) { // y and then x (2d)
|
23 |
+
for (int n = 0; n < N; n++) {
|
24 |
+
|
25 |
+
// grab the token position
|
26 |
+
const int p = pos[b][n][x];
|
27 |
+
|
28 |
+
for (int h = 0; h < H; h++) {
|
29 |
+
for (int d = 0; d < D; d++) {
|
30 |
+
// grab the two values
|
31 |
+
float u = tok[b][n][h][d+0+x*2*D];
|
32 |
+
float v = tok[b][n][h][d+D+x*2*D];
|
33 |
+
|
34 |
+
// grab the cos,sin
|
35 |
+
const float inv_freq = fwd * p / powf(base, d/float(D));
|
36 |
+
float c = cosf(inv_freq);
|
37 |
+
float s = sinf(inv_freq);
|
38 |
+
|
39 |
+
// write the result
|
40 |
+
tok[b][n][h][d+0+x*2*D] = u*c - v*s;
|
41 |
+
tok[b][n][h][d+D+x*2*D] = v*c + u*s;
|
42 |
+
}
|
43 |
+
}
|
44 |
+
}
|
45 |
+
}
|
46 |
+
}
|
47 |
+
}
|
48 |
+
|
49 |
+
void rope_2d( torch::Tensor tokens, // B,N,H,D
|
50 |
+
const torch::Tensor positions, // B,N,2
|
51 |
+
const float base,
|
52 |
+
const float fwd )
|
53 |
+
{
|
54 |
+
TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions");
|
55 |
+
TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions");
|
56 |
+
TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions");
|
57 |
+
TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions");
|
58 |
+
TORCH_CHECK(positions.size(2) == 2, "positions.shape[2] must be equal to 2");
|
59 |
+
TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" );
|
60 |
+
|
61 |
+
if (tokens.is_cuda())
|
62 |
+
rope_2d_cuda( tokens, positions, base, fwd );
|
63 |
+
else
|
64 |
+
rope_2d_cpu( tokens, positions, base, fwd );
|
65 |
+
}
|
66 |
+
|
67 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
68 |
+
m.def("rope_2d", &rope_2d, "RoPE 2d forward/backward");
|
69 |
+
}
|
extensions/curope/curope.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Metadata-Version: 2.1
|
2 |
+
Name: curope
|
3 |
+
Version: 0.0.0
|
4 |
+
Summary: UNKNOWN
|
5 |
+
Home-page: UNKNOWN
|
6 |
+
License: UNKNOWN
|
7 |
+
Platform: UNKNOWN
|
8 |
+
|
9 |
+
UNKNOWN
|
10 |
+
|
extensions/curope/curope.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__init__.py
|
2 |
+
curope.cpp
|
3 |
+
curope2d.py
|
4 |
+
kernels.cu
|
5 |
+
setup.py
|
6 |
+
curope.egg-info/PKG-INFO
|
7 |
+
curope.egg-info/SOURCES.txt
|
8 |
+
curope.egg-info/dependency_links.txt
|
9 |
+
curope.egg-info/top_level.txt
|
extensions/curope/curope.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
extensions/curope/curope.egg-info/top_level.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
curope
|
extensions/curope/curope2d.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
try:
|
7 |
+
import curope as _kernels # run `python setup.py install`
|
8 |
+
except ModuleNotFoundError:
|
9 |
+
from . import curope as _kernels # run `python setup.py build_ext --inplace`
|
10 |
+
|
11 |
+
from torch.amp import custom_fwd, custom_bwd
|
12 |
+
|
13 |
+
class cuRoPE2D_func (torch.autograd.Function):
|
14 |
+
|
15 |
+
@staticmethod
|
16 |
+
@custom_fwd(device_type='cuda', cast_inputs=torch.float32)
|
17 |
+
def forward(ctx, tokens, positions, base, F0=1):
|
18 |
+
ctx.save_for_backward(positions)
|
19 |
+
ctx.saved_base = base
|
20 |
+
ctx.saved_F0 = F0
|
21 |
+
# tokens = tokens.clone() # uncomment this if inplace doesn't work
|
22 |
+
_kernels.rope_2d( tokens, positions, base, F0 )
|
23 |
+
ctx.mark_dirty(tokens)
|
24 |
+
return tokens
|
25 |
+
|
26 |
+
@staticmethod
|
27 |
+
@custom_bwd(device_type='cuda')
|
28 |
+
def backward(ctx, grad_res):
|
29 |
+
positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0
|
30 |
+
_kernels.rope_2d( grad_res, positions, base, -F0 )
|
31 |
+
ctx.mark_dirty(grad_res)
|
32 |
+
return grad_res, None, None, None
|
33 |
+
|
34 |
+
|
35 |
+
class cuRoPE2D(torch.nn.Module):
|
36 |
+
def __init__(self, freq=100.0, F0=1.0):
|
37 |
+
super().__init__()
|
38 |
+
self.base = freq
|
39 |
+
self.F0 = F0
|
40 |
+
|
41 |
+
def forward(self, tokens, positions):
|
42 |
+
cuRoPE2D_func.apply( tokens.transpose(1,2), positions, self.base, self.F0 )
|
43 |
+
return tokens
|
extensions/curope/kernels.cu
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
3 |
+
Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
4 |
+
*/
|
5 |
+
|
6 |
+
#include <torch/extension.h>
|
7 |
+
#include <cuda.h>
|
8 |
+
#include <cuda_runtime.h>
|
9 |
+
#include <vector>
|
10 |
+
|
11 |
+
#define CHECK_CUDA(tensor) {\
|
12 |
+
TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \
|
13 |
+
TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); }
|
14 |
+
void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));}
|
15 |
+
|
16 |
+
|
17 |
+
template < typename scalar_t >
|
18 |
+
__global__ void rope_2d_cuda_kernel(
|
19 |
+
//scalar_t* __restrict__ tokens,
|
20 |
+
torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> tokens,
|
21 |
+
const int64_t* __restrict__ pos,
|
22 |
+
const float base,
|
23 |
+
const float fwd )
|
24 |
+
// const int N, const int H, const int D )
|
25 |
+
{
|
26 |
+
// tokens shape = (B, N, H, D)
|
27 |
+
const int N = tokens.size(1);
|
28 |
+
const int H = tokens.size(2);
|
29 |
+
const int D = tokens.size(3);
|
30 |
+
|
31 |
+
// each block update a single token, for all heads
|
32 |
+
// each thread takes care of a single output
|
33 |
+
extern __shared__ float shared[];
|
34 |
+
float* shared_inv_freq = shared + D;
|
35 |
+
|
36 |
+
const int b = blockIdx.x / N;
|
37 |
+
const int n = blockIdx.x % N;
|
38 |
+
|
39 |
+
const int Q = D / 4;
|
40 |
+
// one token = [0..Q : Q..2Q : 2Q..3Q : 3Q..D]
|
41 |
+
// u_Y v_Y u_X v_X
|
42 |
+
|
43 |
+
// shared memory: first, compute inv_freq
|
44 |
+
if (threadIdx.x < Q)
|
45 |
+
shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q));
|
46 |
+
__syncthreads();
|
47 |
+
|
48 |
+
// start of X or Y part
|
49 |
+
const int X = threadIdx.x < D/2 ? 0 : 1;
|
50 |
+
const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X
|
51 |
+
|
52 |
+
// grab the cos,sin appropriate for me
|
53 |
+
const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q];
|
54 |
+
const float cos = cosf(freq);
|
55 |
+
const float sin = sinf(freq);
|
56 |
+
/*
|
57 |
+
float* shared_cos_sin = shared + D + D/4;
|
58 |
+
if ((threadIdx.x % (D/2)) < Q)
|
59 |
+
shared_cos_sin[m+0] = cosf(freq);
|
60 |
+
else
|
61 |
+
shared_cos_sin[m+Q] = sinf(freq);
|
62 |
+
__syncthreads();
|
63 |
+
const float cos = shared_cos_sin[m+0];
|
64 |
+
const float sin = shared_cos_sin[m+Q];
|
65 |
+
*/
|
66 |
+
|
67 |
+
for (int h = 0; h < H; h++)
|
68 |
+
{
|
69 |
+
// then, load all the token for this head in shared memory
|
70 |
+
shared[threadIdx.x] = tokens[b][n][h][threadIdx.x];
|
71 |
+
__syncthreads();
|
72 |
+
|
73 |
+
const float u = shared[m];
|
74 |
+
const float v = shared[m+Q];
|
75 |
+
|
76 |
+
// write output
|
77 |
+
if ((threadIdx.x % (D/2)) < Q)
|
78 |
+
tokens[b][n][h][threadIdx.x] = u*cos - v*sin;
|
79 |
+
else
|
80 |
+
tokens[b][n][h][threadIdx.x] = v*cos + u*sin;
|
81 |
+
}
|
82 |
+
}
|
83 |
+
|
84 |
+
void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd )
|
85 |
+
{
|
86 |
+
const int B = tokens.size(0); // batch size
|
87 |
+
const int N = tokens.size(1); // sequence length
|
88 |
+
const int H = tokens.size(2); // number of heads
|
89 |
+
const int D = tokens.size(3); // dimension per head
|
90 |
+
|
91 |
+
TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous");
|
92 |
+
TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous");
|
93 |
+
TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape");
|
94 |
+
TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4");
|
95 |
+
|
96 |
+
// one block for each layer, one thread per local-max
|
97 |
+
const int THREADS_PER_BLOCK = D;
|
98 |
+
const int N_BLOCKS = B * N; // each block takes care of H*D values
|
99 |
+
const int SHARED_MEM = sizeof(float) * (D + D/4);
|
100 |
+
|
101 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), "rope_2d_cuda", ([&] {
|
102 |
+
rope_2d_cuda_kernel<scalar_t> <<<N_BLOCKS, THREADS_PER_BLOCK, SHARED_MEM>>> (
|
103 |
+
//tokens.data_ptr<scalar_t>(),
|
104 |
+
tokens.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
|
105 |
+
pos.data_ptr<int64_t>(),
|
106 |
+
base, fwd); //, N, H, D );
|
107 |
+
}));
|
108 |
+
}
|
extensions/curope/setup.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
3 |
+
|
4 |
+
from setuptools import setup
|
5 |
+
from torch import cuda
|
6 |
+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
7 |
+
|
8 |
+
# compile for all possible CUDA architectures
|
9 |
+
all_cuda_archs = cuda.get_gencode_flags().replace('compute=','arch=').split()
|
10 |
+
# alternatively, you can list cuda archs that you want, eg:
|
11 |
+
# all_cuda_archs = [
|
12 |
+
# '-gencode', 'arch=compute_70,code=sm_70',
|
13 |
+
# '-gencode', 'arch=compute_75,code=sm_75',
|
14 |
+
# '-gencode', 'arch=compute_80,code=sm_80',
|
15 |
+
# '-gencode', 'arch=compute_86,code=sm_86'
|
16 |
+
# ]
|
17 |
+
|
18 |
+
setup(
|
19 |
+
name = 'curope',
|
20 |
+
ext_modules = [
|
21 |
+
CUDAExtension(
|
22 |
+
name='curope',
|
23 |
+
sources=[
|
24 |
+
"curope.cpp",
|
25 |
+
"kernels.cu",
|
26 |
+
],
|
27 |
+
extra_compile_args = dict(
|
28 |
+
nvcc=['-O3','--ptxas-options=-v',"--use_fast_math"]+all_cuda_archs,
|
29 |
+
cxx=['-O3'])
|
30 |
+
)
|
31 |
+
],
|
32 |
+
cmdclass = {
|
33 |
+
'build_ext': BuildExtension
|
34 |
+
})
|
input/cam2world.pt
ADDED
Binary file (1.25 kB). View file
|
|
input/intrinsics.pt
ADDED
Binary file (1.07 kB). View file
|
|
main.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
bb = breakpoint
|
2 |
+
import torch
|
3 |
+
from torch.utils.data import DataLoader
|
4 |
+
import wandb
|
5 |
+
from argparse import ArgumentParser
|
6 |
+
from datasets.octmae import OctMae
|
7 |
+
from datasets.foundation_pose import FoundationPose
|
8 |
+
from datasets.generic_loader import GenericLoader
|
9 |
+
|
10 |
+
from utils.collate import collate
|
11 |
+
from models.rayquery import RayQuery
|
12 |
+
from engine import train_epoch, eval_epoch, eval_model
|
13 |
+
import torch.nn as nn
|
14 |
+
from models.rayquery import RayQuery, PointmapEncoder, RayEncoder
|
15 |
+
from models.losses import *
|
16 |
+
import utils.misc as misc
|
17 |
+
import os
|
18 |
+
from utils.viz import just_load_viz
|
19 |
+
from utils.fusion import fuse_batch
|
20 |
+
import socket
|
21 |
+
import time
|
22 |
+
from utils.augmentations import *
|
23 |
+
|
24 |
+
def parse_args():
|
25 |
+
parser = ArgumentParser()
|
26 |
+
parser.add_argument("--dataset_train", type=str, default="TableOfCubes(size=10,n_views=2,seed=747)")
|
27 |
+
parser.add_argument("--dataset_test", type=str, default="TableOfCubes(size=10,n_views=2,seed=787)")
|
28 |
+
parser.add_argument("--dataset_just_load", type=str, default=None)
|
29 |
+
parser.add_argument("--logdir", type=str, default="logs")
|
30 |
+
parser.add_argument("--batch_size", type=int, default=5)
|
31 |
+
parser.add_argument("--n_epochs", type=int, default=100)
|
32 |
+
parser.add_argument("--n_workers", type=int, default=4)
|
33 |
+
parser.add_argument("--model", type=str, default="RayQuery(ray_enc=RayEncoder(),pointmap_enc=PointmapEncoder(),criterion=RayCompletion(ConfLoss(L21)))")
|
34 |
+
parser.add_argument("--save_every", type=int, default=1)
|
35 |
+
parser.add_argument("--resume", type=str, default=None)
|
36 |
+
parser.add_argument("--eval_every", type=int, default=3)
|
37 |
+
parser.add_argument("--wandb_project", type=str, default=None)
|
38 |
+
parser.add_argument("--wandb_run_name", type=str, default="init")
|
39 |
+
parser.add_argument("--just_load", action="store_true")
|
40 |
+
parser.add_argument("--device", type=str, default="cuda")
|
41 |
+
parser.add_argument("--rr_addr", type=str, default="0.0.0.0:"+os.getenv("RERUN_RECORDING","9876"))
|
42 |
+
parser.add_argument("--mesh", action="store_true")
|
43 |
+
parser.add_argument("--max_norm", type=float, default=-1)
|
44 |
+
parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)')
|
45 |
+
parser.add_argument('--blr', type=float, default=1.5e-4, metavar='LR',
|
46 |
+
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
|
47 |
+
parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
|
48 |
+
help='lower lr bound for cyclic schedulers that hit 0')
|
49 |
+
parser.add_argument('--warmup_epochs', type=int, default=10)
|
50 |
+
parser.add_argument('--weight_decay', type=float, default=0.01)
|
51 |
+
parser.add_argument('--normalize_mode',type=str,default='None')
|
52 |
+
parser.add_argument('--start_from',type=str,default=None)
|
53 |
+
parser.add_argument('--augmentor',type=str,default='None')
|
54 |
+
return parser.parse_args()
|
55 |
+
|
56 |
+
def main(args):
|
57 |
+
load_dino = False
|
58 |
+
if not args.just_load:
|
59 |
+
dataset_train = eval(args.dataset_train)
|
60 |
+
dataset_test = eval(args.dataset_test)
|
61 |
+
if not dataset_train.prefetch_dino:
|
62 |
+
load_dino = True
|
63 |
+
rank, world_size, local_rank = misc.setup_distributed()
|
64 |
+
sampler_train = torch.utils.data.DistributedSampler(
|
65 |
+
dataset_train, num_replicas=world_size, rank=rank, shuffle=True
|
66 |
+
)
|
67 |
+
|
68 |
+
sampler_test = torch.utils.data.DistributedSampler(
|
69 |
+
dataset_test, num_replicas=world_size, rank=rank, shuffle=False
|
70 |
+
)
|
71 |
+
|
72 |
+
train_loader = DataLoader(
|
73 |
+
dataset_train, sampler=sampler_train, batch_size=args.batch_size, shuffle=False, collate_fn=collate,
|
74 |
+
num_workers=args.n_workers,
|
75 |
+
pin_memory=True,
|
76 |
+
prefetch_factor=2,
|
77 |
+
drop_last=True
|
78 |
+
)
|
79 |
+
test_loader = DataLoader(
|
80 |
+
dataset_test, sampler=sampler_test, batch_size=args.batch_size, shuffle=False, collate_fn=collate,
|
81 |
+
num_workers=args.n_workers,
|
82 |
+
pin_memory=True,
|
83 |
+
prefetch_factor=2,
|
84 |
+
drop_last=True
|
85 |
+
)
|
86 |
+
|
87 |
+
n_scenes_epoch = len(train_loader) * args.batch_size * world_size
|
88 |
+
print(f"Number of scenes in epoch: {n_scenes_epoch}")
|
89 |
+
else:
|
90 |
+
if args.dataset_just_load is None:
|
91 |
+
dataset = eval(args.dataset_train)
|
92 |
+
else:
|
93 |
+
dataset = eval(args.dataset_just_load)
|
94 |
+
if not dataset.prefetch_dino:
|
95 |
+
load_dino = True
|
96 |
+
rank, world_size, local_rank = misc.setup_distributed()
|
97 |
+
sampler_train = torch.utils.data.DistributedSampler(
|
98 |
+
dataset, num_replicas=world_size, rank=rank, shuffle=False
|
99 |
+
)
|
100 |
+
just_loader = DataLoader(dataset, sampler=sampler_train, batch_size=args.batch_size, shuffle=False, collate_fn=collate,
|
101 |
+
pin_memory=True,
|
102 |
+
drop_last=True
|
103 |
+
)
|
104 |
+
|
105 |
+
model = eval(args.model).to(args.device)
|
106 |
+
if args.augmentor != 'None':
|
107 |
+
augmentor = eval(args.augmentor)
|
108 |
+
else:
|
109 |
+
augmentor = None
|
110 |
+
|
111 |
+
if load_dino and len(model.dino_layers) > 0:
|
112 |
+
dino_model = torch.hub.load('facebookresearch/dinov2', "dinov2_vitl14_reg")
|
113 |
+
dino_model.eval()
|
114 |
+
dino_model.to("cuda")
|
115 |
+
else:
|
116 |
+
dino_model = None
|
117 |
+
# distribute model
|
118 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],find_unused_parameters=True)
|
119 |
+
model_without_ddp = model.module if hasattr(model, 'module') else model
|
120 |
+
|
121 |
+
eff_batch_size = args.batch_size * misc.get_world_size()
|
122 |
+
if args.lr is None: # only base_lr is specified
|
123 |
+
args.lr = args.blr * eff_batch_size / 256
|
124 |
+
|
125 |
+
param_groups = misc.add_weight_decay(model_without_ddp, args.weight_decay)
|
126 |
+
optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
|
127 |
+
os.makedirs(args.logdir,exist_ok=True)
|
128 |
+
start_epoch = 0
|
129 |
+
print("Running on host %s" % socket.gethostname())
|
130 |
+
if args.resume and os.path.exists(os.path.join(args.resume, "checkpoint-latest.pth")):
|
131 |
+
checkpoint = torch.load(os.path.join(args.resume, "checkpoint-latest.pth"), map_location='cpu')
|
132 |
+
model_without_ddp.load_state_dict(checkpoint['model'])
|
133 |
+
model_params = list(model.parameters())
|
134 |
+
print("Resume checkpoint %s" % args.resume)
|
135 |
+
|
136 |
+
if 'optimizer' in checkpoint and 'epoch' in checkpoint:
|
137 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
138 |
+
start_epoch = checkpoint['epoch'] + 1
|
139 |
+
print("With optim & sched!")
|
140 |
+
del checkpoint
|
141 |
+
elif args.start_from is not None:
|
142 |
+
checkpoint = torch.load(args.start_from, map_location='cpu')
|
143 |
+
model_without_ddp.load_state_dict(checkpoint['model'])
|
144 |
+
print("Start from checkpoint %s" % args.start_from)
|
145 |
+
if args.just_load:
|
146 |
+
with torch.no_grad():
|
147 |
+
while True:
|
148 |
+
#test_log_dict = eval_epoch(model,just_loader,device=args.device,dino_model=dino_model,args=args)
|
149 |
+
for data in just_loader:
|
150 |
+
pred, gt, loss_dict, batch = eval_model(model,data,mode='viz',args=args,dino_model=dino_model,augmentor=augmentor)
|
151 |
+
# cast to float32 for visualization
|
152 |
+
gt = {k: v.float() for k, v in gt.items()}
|
153 |
+
pred = {k: v.float() for k, v in pred.items()}
|
154 |
+
#loss_dict = eval_model(model,data,mode='loss',device=args.device)
|
155 |
+
#print(f"Loss: {loss_dict['loss']:.4f}")
|
156 |
+
# summarize all keys in loss_dict in table
|
157 |
+
print(f"{'Key':<10} {'Value':<10}")
|
158 |
+
print("-"*20)
|
159 |
+
for key, value in loss_dict.items():
|
160 |
+
print(f"{key:<10}: {value:.4f}")
|
161 |
+
print("-"*20)
|
162 |
+
name = args.logdir
|
163 |
+
addr = args.rr_addr
|
164 |
+
if args.mesh:
|
165 |
+
fused_meshes = fuse_batch(pred,gt,data, voxel_size=0.002)
|
166 |
+
else:
|
167 |
+
fused_meshes = None
|
168 |
+
just_load_viz(pred,gt,batch,addr=addr,name=name,fused_meshes=fused_meshes)
|
169 |
+
breakpoint()
|
170 |
+
return
|
171 |
+
else:
|
172 |
+
if args.wandb_project and misc.get_rank() == 0:
|
173 |
+
wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=args)
|
174 |
+
log_wandb = args.wandb_project
|
175 |
+
else:
|
176 |
+
log_wandb = None
|
177 |
+
for epoch in range(start_epoch,args.n_epochs):
|
178 |
+
start_time = time.time()
|
179 |
+
log_dict = train_epoch(model,train_loader,optimizer,device=args.device,max_norm=args.max_norm,epoch=epoch,
|
180 |
+
log_wandb=log_wandb,batch_size=eff_batch_size,args=args,dino_model=dino_model,augmentor=augmentor)
|
181 |
+
end_time = time.time()
|
182 |
+
print(f"Epoch {epoch} train loss: {log_dict['loss']:.4f} grad_norm: {log_dict['grad_norm']:.4f} \n")
|
183 |
+
print(f"Time taken for epoch {epoch}: {end_time - start_time:.2f} seconds")
|
184 |
+
|
185 |
+
if epoch % args.eval_every == 0:
|
186 |
+
test_log_dict = eval_epoch(model,test_loader,device=args.device,dino_model=dino_model,args=args,augmentor=augmentor)
|
187 |
+
print(f"Epoch {epoch} test loss: {test_log_dict['loss']:.4f} \n")
|
188 |
+
if log_wandb:
|
189 |
+
wandb_dict = {f"test_{k}":v for k,v in test_log_dict.items()}
|
190 |
+
wandb.log(wandb_dict, step=(epoch+1)*n_scenes_epoch)
|
191 |
+
if epoch % args.save_every == 0:
|
192 |
+
# this saves the model every epoch and doesn't overwrite but it becomes tremendous, huge
|
193 |
+
#misc.save_model(args, epoch, model, optimizer)
|
194 |
+
misc.save_model(args, epoch, model_without_ddp, optimizer, epoch_name=f"latest")
|
195 |
+
|
196 |
+
if __name__ == "__main__":
|
197 |
+
args = parse_args()
|
198 |
+
main(args)
|
models/blocks.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copied from: https://github.com/naver/croco/blob/743ee71a2a9bf57cea6832a9064a70a0597fcfcb/models/blocks.py
|
2 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
3 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from itertools import repeat
|
9 |
+
import collections.abc
|
10 |
+
|
11 |
+
def _ntuple(n):
|
12 |
+
def parse(x):
|
13 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
14 |
+
return x
|
15 |
+
return tuple(repeat(x, n))
|
16 |
+
return parse
|
17 |
+
to_2tuple = _ntuple(2)
|
18 |
+
|
19 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
|
20 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
21 |
+
"""
|
22 |
+
if drop_prob == 0. or not training:
|
23 |
+
return x
|
24 |
+
keep_prob = 1 - drop_prob
|
25 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
26 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
27 |
+
if keep_prob > 0.0 and scale_by_keep:
|
28 |
+
random_tensor.div_(keep_prob)
|
29 |
+
return x * random_tensor
|
30 |
+
|
31 |
+
class DropPath(nn.Module):
|
32 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
33 |
+
"""
|
34 |
+
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
|
35 |
+
super(DropPath, self).__init__()
|
36 |
+
self.drop_prob = drop_prob
|
37 |
+
self.scale_by_keep = scale_by_keep
|
38 |
+
|
39 |
+
def forward(self, x):
|
40 |
+
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
41 |
+
|
42 |
+
def extra_repr(self):
|
43 |
+
return f'drop_prob={round(self.drop_prob,3):0.3f}'
|
44 |
+
|
45 |
+
class Mlp(nn.Module):
|
46 |
+
""" MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
47 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
|
48 |
+
super().__init__()
|
49 |
+
out_features = out_features or in_features
|
50 |
+
hidden_features = hidden_features or in_features
|
51 |
+
bias = to_2tuple(bias)
|
52 |
+
drop_probs = to_2tuple(drop)
|
53 |
+
|
54 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
|
55 |
+
self.act = act_layer()
|
56 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
57 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
|
58 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
59 |
+
|
60 |
+
def forward(self, x):
|
61 |
+
x = self.fc1(x)
|
62 |
+
x = self.act(x)
|
63 |
+
x = self.drop1(x)
|
64 |
+
x = self.fc2(x)
|
65 |
+
x = self.drop2(x)
|
66 |
+
return x
|
67 |
+
|
68 |
+
class Attention(nn.Module):
|
69 |
+
|
70 |
+
def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
71 |
+
super().__init__()
|
72 |
+
self.num_heads = num_heads
|
73 |
+
head_dim = dim // num_heads
|
74 |
+
self.scale = head_dim ** -0.5
|
75 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
76 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
77 |
+
self.proj = nn.Linear(dim, dim)
|
78 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
79 |
+
self.rope = rope
|
80 |
+
|
81 |
+
def forward(self, x, xpos):
|
82 |
+
B, N, C = x.shape
|
83 |
+
|
84 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1,3)
|
85 |
+
q, k, v = [qkv[:,:,i] for i in range(3)]
|
86 |
+
# q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple)
|
87 |
+
|
88 |
+
if self.rope is not None:
|
89 |
+
q = self.rope(q, xpos)
|
90 |
+
k = self.rope(k, xpos)
|
91 |
+
|
92 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
93 |
+
attn = attn.softmax(dim=-1)
|
94 |
+
attn = self.attn_drop(attn)
|
95 |
+
|
96 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
97 |
+
x = self.proj(x)
|
98 |
+
x = self.proj_drop(x)
|
99 |
+
return x
|
100 |
+
|
101 |
+
class Block(nn.Module):
|
102 |
+
|
103 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
104 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, rope=None):
|
105 |
+
super().__init__()
|
106 |
+
self.norm1 = norm_layer(dim)
|
107 |
+
self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
108 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
109 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
110 |
+
self.norm2 = norm_layer(dim)
|
111 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
112 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
113 |
+
|
114 |
+
def forward(self, x, xpos):
|
115 |
+
x = x + self.drop_path(self.attn(self.norm1(x), xpos))
|
116 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
117 |
+
return x
|
118 |
+
|
119 |
+
class CrossAttention(nn.Module):
|
120 |
+
|
121 |
+
def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
122 |
+
super().__init__()
|
123 |
+
self.num_heads = num_heads
|
124 |
+
head_dim = dim // num_heads
|
125 |
+
self.scale = head_dim ** -0.5
|
126 |
+
|
127 |
+
self.projq = nn.Linear(dim, dim, bias=qkv_bias)
|
128 |
+
self.projk = nn.Linear(dim, dim, bias=qkv_bias)
|
129 |
+
self.projv = nn.Linear(dim, dim, bias=qkv_bias)
|
130 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
131 |
+
self.proj = nn.Linear(dim, dim)
|
132 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
133 |
+
|
134 |
+
self.rope = rope
|
135 |
+
|
136 |
+
def forward(self, query, key, value, qpos, kpos):
|
137 |
+
B, Nq, C = query.shape
|
138 |
+
Nk = key.shape[1]
|
139 |
+
Nv = value.shape[1]
|
140 |
+
|
141 |
+
q = self.projq(query).reshape(B,Nq,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
|
142 |
+
k = self.projk(key).reshape(B,Nk,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
|
143 |
+
v = self.projv(value).reshape(B,Nv,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
|
144 |
+
|
145 |
+
if self.rope is not None:
|
146 |
+
q = self.rope(q, qpos)
|
147 |
+
k = self.rope(k, kpos)
|
148 |
+
|
149 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
150 |
+
attn = attn.softmax(dim=-1)
|
151 |
+
attn = self.attn_drop(attn)
|
152 |
+
|
153 |
+
x = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
|
154 |
+
x = self.proj(x)
|
155 |
+
x = self.proj_drop(x)
|
156 |
+
return x
|
157 |
+
|
158 |
+
class DecoderBlock(nn.Module):
|
159 |
+
|
160 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
161 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None,order='sa_ca'):
|
162 |
+
super().__init__()
|
163 |
+
self.norm1 = norm_layer(dim)
|
164 |
+
self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
165 |
+
self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
166 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
167 |
+
self.norm2 = norm_layer(dim)
|
168 |
+
self.norm3 = norm_layer(dim)
|
169 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
170 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
171 |
+
self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
|
172 |
+
self.order = order
|
173 |
+
self.batch_drop_path_prob = -drop_path if drop_path < 0. else 0.
|
174 |
+
|
175 |
+
def forward(self, x, y, xpos, ypos):
|
176 |
+
if self.order == 'sa_ca':
|
177 |
+
if self.batch_drop_path_prob==0.0 or not self.training or torch.rand(1).item()>=self.batch_drop_path_prob: x = x + self.drop_path(self.attn(self.norm1(x), xpos))
|
178 |
+
y_ = self.norm_y(y)
|
179 |
+
if self.batch_drop_path_prob==0.0 or not self.training or torch.rand(1).item()>=self.batch_drop_path_prob: x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
|
180 |
+
if self.batch_drop_path_prob==0.0 or not self.training or torch.rand(1).item()>=self.batch_drop_path_prob: x = x + self.drop_path(self.mlp(self.norm3(x)))
|
181 |
+
elif self.order == 'ca_sa':
|
182 |
+
y_ = self.norm_y(y)
|
183 |
+
if self.batch_drop_path_prob==0.0 or not self.training or torch.rand(1).item()>=self.batch_drop_path_prob: x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
|
184 |
+
if self.batch_drop_path_prob==0.0 or not self.training or torch.rand(1).item()>=self.batch_drop_path_prob: x = x + self.drop_path(self.attn(self.norm1(x), xpos))
|
185 |
+
if self.batch_drop_path_prob==0.0 or not self.training or torch.rand(1).item()>=self.batch_drop_path_prob: x = x + self.drop_path(self.mlp(self.norm3(x)))
|
186 |
+
return x, y
|
187 |
+
|
188 |
+
|
189 |
+
# patch embedding
|
190 |
+
class PositionGetter(object):
|
191 |
+
""" return positions of patches """
|
192 |
+
|
193 |
+
def __init__(self):
|
194 |
+
self.cache_positions = {}
|
195 |
+
|
196 |
+
def __call__(self, b, h, w, device):
|
197 |
+
if not (h,w) in self.cache_positions:
|
198 |
+
x = torch.arange(w, device=device)
|
199 |
+
y = torch.arange(h, device=device)
|
200 |
+
self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h, w, 2)
|
201 |
+
pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone()
|
202 |
+
return pos
|
203 |
+
|
204 |
+
class PatchEmbed(nn.Module):
|
205 |
+
""" just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed"""
|
206 |
+
|
207 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
|
208 |
+
super().__init__()
|
209 |
+
img_size = to_2tuple(img_size)
|
210 |
+
patch_size = to_2tuple(patch_size)
|
211 |
+
self.img_size = img_size
|
212 |
+
self.patch_size = patch_size
|
213 |
+
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
214 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
215 |
+
self.flatten = flatten
|
216 |
+
|
217 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
218 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
219 |
+
|
220 |
+
self.position_getter = PositionGetter()
|
221 |
+
|
222 |
+
def forward(self, x):
|
223 |
+
B, C, H, W = x.shape
|
224 |
+
torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
|
225 |
+
torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
|
226 |
+
x = self.proj(x)
|
227 |
+
pos = self.position_getter(B, x.size(2), x.size(3), x.device)
|
228 |
+
if self.flatten:
|
229 |
+
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
230 |
+
x = self.norm(x)
|
231 |
+
return x, pos
|
232 |
+
|
233 |
+
def _init_weights(self):
|
234 |
+
w = self.proj.weight.data
|
235 |
+
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
models/heads/__init__.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
3 |
+
#
|
4 |
+
# --------------------------------------------------------
|
5 |
+
# head factory
|
6 |
+
# --------------------------------------------------------
|
7 |
+
from .linear_head import LinearPts3d
|
8 |
+
from .dpt_head import create_dpt_head, create_dpt_head_mask, create_dpt_head_depth
|
9 |
+
|
10 |
+
def head_factory(head_type, output_mode, net, has_conf=False):
|
11 |
+
"""" build a prediction head for the decoder
|
12 |
+
"""
|
13 |
+
if head_type == 'linear' and output_mode == 'pts3d':
|
14 |
+
return LinearPts3d(net, has_conf)
|
15 |
+
if head_type == 'linear_depth' and output_mode == 'pts3d':
|
16 |
+
return LinearPts3d(net, has_conf,mode='depth')
|
17 |
+
if head_type == 'linear_classifier' and output_mode == 'pts3d':
|
18 |
+
return LinearPts3d(net, has_conf,mode='classifier')
|
19 |
+
elif head_type == 'dpt' and output_mode == 'pts3d':
|
20 |
+
return create_dpt_head(net, has_conf=has_conf)
|
21 |
+
elif head_type == 'dpt_depth' and output_mode == 'pts3d':
|
22 |
+
return create_dpt_head_depth(net, has_conf=has_conf)
|
23 |
+
elif head_type == 'dpt_mask' and output_mode == 'pts3d':
|
24 |
+
return create_dpt_head_mask(net, has_conf=has_conf)
|
25 |
+
else:
|
26 |
+
raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")
|
models/heads/dpt_head.py
ADDED
@@ -0,0 +1,582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops import rearrange, repeat
|
5 |
+
from typing import Union, Tuple, Iterable, List, Optional, Dict
|
6 |
+
from .postprocess import postprocess
|
7 |
+
|
8 |
+
def pair(t):
|
9 |
+
return t if isinstance(t, tuple) else (t, t)
|
10 |
+
|
11 |
+
def make_scratch(in_shape, out_shape, groups=1, expand=False):
|
12 |
+
scratch = nn.Module()
|
13 |
+
|
14 |
+
out_shape1 = out_shape
|
15 |
+
out_shape2 = out_shape
|
16 |
+
out_shape3 = out_shape
|
17 |
+
out_shape4 = out_shape
|
18 |
+
if expand == True:
|
19 |
+
out_shape1 = out_shape
|
20 |
+
out_shape2 = out_shape * 2
|
21 |
+
out_shape3 = out_shape * 4
|
22 |
+
out_shape4 = out_shape * 8
|
23 |
+
|
24 |
+
scratch.layer1_rn = nn.Conv2d(
|
25 |
+
in_shape[0],
|
26 |
+
out_shape1,
|
27 |
+
kernel_size=3,
|
28 |
+
stride=1,
|
29 |
+
padding=1,
|
30 |
+
bias=False,
|
31 |
+
groups=groups,
|
32 |
+
)
|
33 |
+
scratch.layer2_rn = nn.Conv2d(
|
34 |
+
in_shape[1],
|
35 |
+
out_shape2,
|
36 |
+
kernel_size=3,
|
37 |
+
stride=1,
|
38 |
+
padding=1,
|
39 |
+
bias=False,
|
40 |
+
groups=groups,
|
41 |
+
)
|
42 |
+
scratch.layer3_rn = nn.Conv2d(
|
43 |
+
in_shape[2],
|
44 |
+
out_shape3,
|
45 |
+
kernel_size=3,
|
46 |
+
stride=1,
|
47 |
+
padding=1,
|
48 |
+
bias=False,
|
49 |
+
groups=groups,
|
50 |
+
)
|
51 |
+
scratch.layer4_rn = nn.Conv2d(
|
52 |
+
in_shape[3],
|
53 |
+
out_shape4,
|
54 |
+
kernel_size=3,
|
55 |
+
stride=1,
|
56 |
+
padding=1,
|
57 |
+
bias=False,
|
58 |
+
groups=groups,
|
59 |
+
)
|
60 |
+
|
61 |
+
scratch.layer_rn = nn.ModuleList([
|
62 |
+
scratch.layer1_rn,
|
63 |
+
scratch.layer2_rn,
|
64 |
+
scratch.layer3_rn,
|
65 |
+
scratch.layer4_rn,
|
66 |
+
])
|
67 |
+
|
68 |
+
return scratch
|
69 |
+
|
70 |
+
class ResidualConvUnit_custom(nn.Module):
|
71 |
+
"""Residual convolution module."""
|
72 |
+
|
73 |
+
def __init__(self, features, activation, bn):
|
74 |
+
"""Init.
|
75 |
+
Args:
|
76 |
+
features (int): number of features
|
77 |
+
"""
|
78 |
+
super().__init__()
|
79 |
+
|
80 |
+
self.bn = bn
|
81 |
+
self.groups = 1
|
82 |
+
|
83 |
+
self.conv1 = nn.Conv2d(
|
84 |
+
features,
|
85 |
+
features,
|
86 |
+
kernel_size=3,
|
87 |
+
stride=1,
|
88 |
+
padding=1,
|
89 |
+
bias=not self.bn,
|
90 |
+
groups=self.groups,
|
91 |
+
)
|
92 |
+
|
93 |
+
self.conv2 = nn.Conv2d(
|
94 |
+
features,
|
95 |
+
features,
|
96 |
+
kernel_size=3,
|
97 |
+
stride=1,
|
98 |
+
padding=1,
|
99 |
+
bias=not self.bn,
|
100 |
+
groups=self.groups,
|
101 |
+
)
|
102 |
+
|
103 |
+
if self.bn == True:
|
104 |
+
self.bn1 = nn.BatchNorm2d(features)
|
105 |
+
self.bn2 = nn.BatchNorm2d(features)
|
106 |
+
|
107 |
+
self.activation = activation
|
108 |
+
|
109 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
110 |
+
|
111 |
+
def forward(self, x):
|
112 |
+
"""Forward pass.
|
113 |
+
Args:
|
114 |
+
x (tensor): input
|
115 |
+
Returns:
|
116 |
+
tensor: output
|
117 |
+
"""
|
118 |
+
|
119 |
+
out = self.activation(x)
|
120 |
+
out = self.conv1(out)
|
121 |
+
if self.bn == True:
|
122 |
+
out = self.bn1(out)
|
123 |
+
|
124 |
+
out = self.activation(out)
|
125 |
+
out = self.conv2(out)
|
126 |
+
if self.bn == True:
|
127 |
+
out = self.bn2(out)
|
128 |
+
|
129 |
+
if self.groups > 1:
|
130 |
+
out = self.conv_merge(out)
|
131 |
+
|
132 |
+
return self.skip_add.add(out, x)
|
133 |
+
|
134 |
+
class FeatureFusionBlock_custom(nn.Module):
|
135 |
+
"""Feature fusion block."""
|
136 |
+
|
137 |
+
def __init__(
|
138 |
+
self,
|
139 |
+
features,
|
140 |
+
activation,
|
141 |
+
deconv=False,
|
142 |
+
bn=False,
|
143 |
+
expand=False,
|
144 |
+
align_corners=True,
|
145 |
+
width_ratio=1,
|
146 |
+
):
|
147 |
+
"""Init.
|
148 |
+
Args:
|
149 |
+
features (int): number of features
|
150 |
+
"""
|
151 |
+
super(FeatureFusionBlock_custom, self).__init__()
|
152 |
+
self.width_ratio = width_ratio
|
153 |
+
|
154 |
+
self.deconv = deconv
|
155 |
+
self.align_corners = align_corners
|
156 |
+
|
157 |
+
self.groups = 1
|
158 |
+
|
159 |
+
self.expand = expand
|
160 |
+
out_features = features
|
161 |
+
if self.expand == True:
|
162 |
+
out_features = features // 2
|
163 |
+
|
164 |
+
self.out_conv = nn.Conv2d(
|
165 |
+
features,
|
166 |
+
out_features,
|
167 |
+
kernel_size=1,
|
168 |
+
stride=1,
|
169 |
+
padding=0,
|
170 |
+
bias=True,
|
171 |
+
groups=1,
|
172 |
+
)
|
173 |
+
|
174 |
+
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
|
175 |
+
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
|
176 |
+
|
177 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
178 |
+
|
179 |
+
def forward(self, *xs):
|
180 |
+
"""Forward pass.
|
181 |
+
Returns:
|
182 |
+
tensor: output
|
183 |
+
"""
|
184 |
+
output = xs[0]
|
185 |
+
|
186 |
+
if len(xs) == 2:
|
187 |
+
res = self.resConfUnit1(xs[1])
|
188 |
+
if self.width_ratio != 1:
|
189 |
+
res = F.interpolate(res, size=(output.shape[2], output.shape[3]), mode='bilinear')
|
190 |
+
|
191 |
+
output = self.skip_add.add(output, res)
|
192 |
+
# output += res
|
193 |
+
|
194 |
+
output = self.resConfUnit2(output)
|
195 |
+
|
196 |
+
if self.width_ratio != 1:
|
197 |
+
# and output.shape[3] < self.width_ratio * output.shape[2]
|
198 |
+
#size=(image.shape[])
|
199 |
+
if (output.shape[3] / output.shape[2]) < (2 / 3) * self.width_ratio:
|
200 |
+
shape = 3 * output.shape[3]
|
201 |
+
else:
|
202 |
+
shape = int(self.width_ratio * 2 * output.shape[2])
|
203 |
+
output = F.interpolate(output, size=(2* output.shape[2], shape), mode='bilinear')
|
204 |
+
else:
|
205 |
+
output = nn.functional.interpolate(output, scale_factor=2,
|
206 |
+
mode="bilinear", align_corners=self.align_corners)
|
207 |
+
output = self.out_conv(output)
|
208 |
+
return output
|
209 |
+
|
210 |
+
def make_fusion_block(features, use_bn, width_ratio=1):
|
211 |
+
return FeatureFusionBlock_custom(
|
212 |
+
features,
|
213 |
+
nn.ReLU(False),
|
214 |
+
deconv=False,
|
215 |
+
bn=use_bn,
|
216 |
+
expand=False,
|
217 |
+
align_corners=True,
|
218 |
+
width_ratio=width_ratio,
|
219 |
+
)
|
220 |
+
|
221 |
+
class Interpolate(nn.Module):
|
222 |
+
"""Interpolation module."""
|
223 |
+
|
224 |
+
def __init__(self, scale_factor, mode, align_corners=False):
|
225 |
+
"""Init.
|
226 |
+
Args:
|
227 |
+
scale_factor (float): scaling
|
228 |
+
mode (str): interpolation mode
|
229 |
+
"""
|
230 |
+
super(Interpolate, self).__init__()
|
231 |
+
|
232 |
+
self.interp = nn.functional.interpolate
|
233 |
+
self.scale_factor = scale_factor
|
234 |
+
self.mode = mode
|
235 |
+
self.align_corners = align_corners
|
236 |
+
|
237 |
+
def forward(self, x):
|
238 |
+
"""Forward pass.
|
239 |
+
Args:
|
240 |
+
x (tensor): input
|
241 |
+
Returns:
|
242 |
+
tensor: interpolated data
|
243 |
+
"""
|
244 |
+
|
245 |
+
x = self.interp(
|
246 |
+
x,
|
247 |
+
scale_factor=self.scale_factor,
|
248 |
+
mode=self.mode,
|
249 |
+
align_corners=self.align_corners,
|
250 |
+
)
|
251 |
+
|
252 |
+
return x
|
253 |
+
|
254 |
+
class DPTOutputAdapter(nn.Module):
|
255 |
+
"""DPT output adapter.
|
256 |
+
|
257 |
+
:param num_cahnnels: Number of output channels
|
258 |
+
:param stride_level: stride level compared to the full-sized image.
|
259 |
+
E.g. 4 for 1/4th the size of the image.
|
260 |
+
:param patch_size_full: Int or tuple of the patch size over the full image size.
|
261 |
+
Patch size for smaller inputs will be computed accordingly.
|
262 |
+
:param hooks: Index of intermediate layers
|
263 |
+
:param layer_dims: Dimension of intermediate layers
|
264 |
+
:param feature_dim: Feature dimension
|
265 |
+
:param last_dim: out_channels/in_channels for the last two Conv2d when head_type == regression
|
266 |
+
:param use_bn: If set to True, activates batch norm
|
267 |
+
:param dim_tokens_enc: Dimension of tokens coming from encoder
|
268 |
+
"""
|
269 |
+
|
270 |
+
def __init__(self,
|
271 |
+
num_channels: int = 1,
|
272 |
+
stride_level: int = 1,
|
273 |
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
274 |
+
main_tasks: Iterable[str] = ('rgb',),
|
275 |
+
hooks: List[int] = [2, 5, 8, 11],
|
276 |
+
layer_dims: List[int] = [96, 192, 384, 768],
|
277 |
+
feature_dim: int = 256,
|
278 |
+
last_dim: int = 32,
|
279 |
+
use_bn: bool = False,
|
280 |
+
dim_tokens_enc: Optional[int] = None,
|
281 |
+
head_type: str = 'regression',
|
282 |
+
output_width_ratio=1,
|
283 |
+
**kwargs):
|
284 |
+
super().__init__()
|
285 |
+
self.num_channels = num_channels
|
286 |
+
self.stride_level = stride_level
|
287 |
+
self.patch_size = pair(patch_size)
|
288 |
+
self.main_tasks = main_tasks
|
289 |
+
self.hooks = hooks
|
290 |
+
self.layer_dims = layer_dims
|
291 |
+
self.feature_dim = feature_dim
|
292 |
+
self.dim_tokens_enc = dim_tokens_enc * len(self.main_tasks) if dim_tokens_enc is not None else None
|
293 |
+
self.head_type = head_type
|
294 |
+
|
295 |
+
# Actual patch height and width, taking into account stride of input
|
296 |
+
self.P_H = max(1, self.patch_size[0] // stride_level)
|
297 |
+
self.P_W = max(1, self.patch_size[1] // stride_level)
|
298 |
+
|
299 |
+
self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False)
|
300 |
+
self.scratch.refinenet1 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
|
301 |
+
self.scratch.refinenet2 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
|
302 |
+
self.scratch.refinenet3 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
|
303 |
+
self.scratch.refinenet4 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
|
304 |
+
|
305 |
+
if self.head_type == 'regression':
|
306 |
+
# The "DPTDepthModel" head
|
307 |
+
self.head = nn.Sequential(
|
308 |
+
nn.Conv2d(feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1),
|
309 |
+
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
310 |
+
nn.Conv2d(feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1),
|
311 |
+
nn.ReLU(True),
|
312 |
+
nn.Conv2d(last_dim, self.num_channels, kernel_size=1, stride=1, padding=0)
|
313 |
+
)
|
314 |
+
elif self.head_type == 'semseg':
|
315 |
+
# The "DPTSegmentationModel" head
|
316 |
+
self.head = nn.Sequential(
|
317 |
+
nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1, bias=False),
|
318 |
+
nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(),
|
319 |
+
nn.ReLU(True),
|
320 |
+
nn.Dropout(0.1, False),
|
321 |
+
nn.Conv2d(feature_dim, self.num_channels, kernel_size=1),
|
322 |
+
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
323 |
+
)
|
324 |
+
else:
|
325 |
+
raise ValueError('DPT head_type must be "regression" or "semseg".')
|
326 |
+
|
327 |
+
if self.dim_tokens_enc is not None:
|
328 |
+
self.init(dim_tokens_enc=dim_tokens_enc)
|
329 |
+
|
330 |
+
def init(self, dim_tokens_enc=768):
|
331 |
+
"""
|
332 |
+
Initialize parts of decoder that are dependent on dimension of encoder tokens.
|
333 |
+
Should be called when setting up MultiMAE.
|
334 |
+
|
335 |
+
:param dim_tokens_enc: Dimension of tokens coming from encoder
|
336 |
+
"""
|
337 |
+
#print(dim_tokens_enc)
|
338 |
+
|
339 |
+
# Set up activation postprocessing layers
|
340 |
+
if isinstance(dim_tokens_enc, int):
|
341 |
+
dim_tokens_enc = 4 * [dim_tokens_enc]
|
342 |
+
|
343 |
+
self.dim_tokens_enc = [dt * len(self.main_tasks) for dt in dim_tokens_enc]
|
344 |
+
|
345 |
+
self.act_1_postprocess = nn.Sequential(
|
346 |
+
nn.Conv2d(
|
347 |
+
in_channels=self.dim_tokens_enc[0],
|
348 |
+
out_channels=self.layer_dims[0],
|
349 |
+
kernel_size=1, stride=1, padding=0,
|
350 |
+
),
|
351 |
+
nn.ConvTranspose2d(
|
352 |
+
in_channels=self.layer_dims[0],
|
353 |
+
out_channels=self.layer_dims[0],
|
354 |
+
kernel_size=4, stride=4, padding=0,
|
355 |
+
bias=True, dilation=1, groups=1,
|
356 |
+
)
|
357 |
+
)
|
358 |
+
|
359 |
+
self.act_2_postprocess = nn.Sequential(
|
360 |
+
nn.Conv2d(
|
361 |
+
in_channels=self.dim_tokens_enc[1],
|
362 |
+
out_channels=self.layer_dims[1],
|
363 |
+
kernel_size=1, stride=1, padding=0,
|
364 |
+
),
|
365 |
+
nn.ConvTranspose2d(
|
366 |
+
in_channels=self.layer_dims[1],
|
367 |
+
out_channels=self.layer_dims[1],
|
368 |
+
kernel_size=2, stride=2, padding=0,
|
369 |
+
bias=True, dilation=1, groups=1,
|
370 |
+
)
|
371 |
+
)
|
372 |
+
|
373 |
+
self.act_3_postprocess = nn.Sequential(
|
374 |
+
nn.Conv2d(
|
375 |
+
in_channels=self.dim_tokens_enc[2],
|
376 |
+
out_channels=self.layer_dims[2],
|
377 |
+
kernel_size=1, stride=1, padding=0,
|
378 |
+
)
|
379 |
+
)
|
380 |
+
|
381 |
+
self.act_4_postprocess = nn.Sequential(
|
382 |
+
nn.Conv2d(
|
383 |
+
in_channels=self.dim_tokens_enc[3],
|
384 |
+
out_channels=self.layer_dims[3],
|
385 |
+
kernel_size=1, stride=1, padding=0,
|
386 |
+
),
|
387 |
+
nn.Conv2d(
|
388 |
+
in_channels=self.layer_dims[3],
|
389 |
+
out_channels=self.layer_dims[3],
|
390 |
+
kernel_size=3, stride=2, padding=1,
|
391 |
+
)
|
392 |
+
)
|
393 |
+
|
394 |
+
self.act_postprocess = nn.ModuleList([
|
395 |
+
self.act_1_postprocess,
|
396 |
+
self.act_2_postprocess,
|
397 |
+
self.act_3_postprocess,
|
398 |
+
self.act_4_postprocess
|
399 |
+
])
|
400 |
+
|
401 |
+
def adapt_tokens(self, encoder_tokens):
|
402 |
+
# Adapt tokens
|
403 |
+
x = []
|
404 |
+
x.append(encoder_tokens[:, :])
|
405 |
+
x = torch.cat(x, dim=-1)
|
406 |
+
return x
|
407 |
+
|
408 |
+
def forward(self, encoder_tokens: List[torch.Tensor], image_size):
|
409 |
+
#input_info: Dict):
|
410 |
+
assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
|
411 |
+
H, W = image_size
|
412 |
+
|
413 |
+
# Number of patches in height and width
|
414 |
+
N_H = H // (self.stride_level * self.P_H)
|
415 |
+
N_W = W // (self.stride_level * self.P_W)
|
416 |
+
|
417 |
+
# Hook decoder onto 4 layers from specified ViT layers
|
418 |
+
layers = [encoder_tokens[hook] for hook in self.hooks]
|
419 |
+
|
420 |
+
# Extract only task-relevant tokens and ignore global tokens.
|
421 |
+
layers = [self.adapt_tokens(l) for l in layers]
|
422 |
+
|
423 |
+
# Reshape tokens to spatial representation
|
424 |
+
layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
|
425 |
+
|
426 |
+
layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
|
427 |
+
# Project layers to chosen feature dim
|
428 |
+
layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
|
429 |
+
|
430 |
+
# Fuse layers using refinement stages
|
431 |
+
path_4 = self.scratch.refinenet4(layers[3])
|
432 |
+
path_3 = self.scratch.refinenet3(path_4, layers[2])
|
433 |
+
path_2 = self.scratch.refinenet2(path_3, layers[1])
|
434 |
+
path_1 = self.scratch.refinenet1(path_2, layers[0])
|
435 |
+
|
436 |
+
# Output head
|
437 |
+
out = self.head(path_1)
|
438 |
+
|
439 |
+
return out
|
440 |
+
|
441 |
+
class DPTOutputAdapter_fix(DPTOutputAdapter):
|
442 |
+
"""
|
443 |
+
Adapt croco's DPTOutputAdapter implementation for dust3r:
|
444 |
+
remove duplicated weigths, and fix forward for dust3r
|
445 |
+
"""
|
446 |
+
|
447 |
+
def init(self, dim_tokens_enc=768,**kwargs):
|
448 |
+
super().init(dim_tokens_enc,**kwargs)
|
449 |
+
# these are duplicated weights
|
450 |
+
del self.act_1_postprocess
|
451 |
+
del self.act_2_postprocess
|
452 |
+
del self.act_3_postprocess
|
453 |
+
del self.act_4_postprocess
|
454 |
+
|
455 |
+
def forward(self, encoder_tokens: List[torch.Tensor], image_size=None):
|
456 |
+
assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
|
457 |
+
# H, W = input_info['image_size']
|
458 |
+
image_size = self.image_size if image_size is None else image_size
|
459 |
+
H, W = image_size
|
460 |
+
# Number of patches in height and width
|
461 |
+
N_H = H // (self.stride_level * self.P_H)
|
462 |
+
N_W = W // (self.stride_level * self.P_W)
|
463 |
+
|
464 |
+
# Hook decoder onto 4 layers from specified ViT layers
|
465 |
+
layers = [encoder_tokens[hook] for hook in self.hooks]
|
466 |
+
|
467 |
+
# Extract only task-relevant tokens and ignore global tokens.
|
468 |
+
layers = [self.adapt_tokens(l) for l in layers]
|
469 |
+
|
470 |
+
# Reshape tokens to spatial representation
|
471 |
+
layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
|
472 |
+
|
473 |
+
layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
|
474 |
+
# Project layers to chosen feature dim
|
475 |
+
layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
|
476 |
+
|
477 |
+
# Fuse layers using refinement stages
|
478 |
+
path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]]
|
479 |
+
path_3 = self.scratch.refinenet3(path_4, layers[2])
|
480 |
+
path_2 = self.scratch.refinenet2(path_3, layers[1])
|
481 |
+
path_1 = self.scratch.refinenet1(path_2, layers[0])
|
482 |
+
|
483 |
+
# Output head
|
484 |
+
out = self.head(path_1)
|
485 |
+
return out
|
486 |
+
|
487 |
+
|
488 |
+
class PixelwiseTaskWithDPT(nn.Module):
|
489 |
+
""" DPT module for dust3r, can return 3D points + confidence for all pixels"""
|
490 |
+
|
491 |
+
def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
|
492 |
+
output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, classifier_mode=None, **kwargs):
|
493 |
+
super(PixelwiseTaskWithDPT, self).__init__()
|
494 |
+
self.return_all_layers = True # backbone needs to return all layers
|
495 |
+
self.postprocess = postprocess
|
496 |
+
self.depth_mode = depth_mode
|
497 |
+
self.conf_mode = conf_mode
|
498 |
+
self.classifier_mode = classifier_mode
|
499 |
+
|
500 |
+
assert n_cls_token == 0, "Not implemented"
|
501 |
+
dpt_args = dict(output_width_ratio=output_width_ratio,
|
502 |
+
num_channels=num_channels,
|
503 |
+
**kwargs)
|
504 |
+
if hooks_idx is not None:
|
505 |
+
dpt_args.update(hooks=hooks_idx)
|
506 |
+
self.dpt = DPTOutputAdapter_fix(**dpt_args)
|
507 |
+
dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens}
|
508 |
+
self.dpt.init(**dpt_init_args)
|
509 |
+
|
510 |
+
def forward(self, x, img_info):
|
511 |
+
out = self.dpt(x, image_size=(img_info[0], img_info[1]))
|
512 |
+
if self.postprocess:
|
513 |
+
out = self.postprocess(out, self.depth_mode, self.conf_mode,self.classifier_mode)
|
514 |
+
return out
|
515 |
+
|
516 |
+
def create_dpt_head(net, has_conf=False):
|
517 |
+
"""
|
518 |
+
return PixelwiseTaskWithDPT for given net params
|
519 |
+
"""
|
520 |
+
assert net.dec_depth > 9
|
521 |
+
l2 = net.dec_depth - 1
|
522 |
+
feature_dim = 256
|
523 |
+
last_dim = feature_dim//2
|
524 |
+
out_nchan = 3
|
525 |
+
ed = net.enc_embed_dim
|
526 |
+
dd = net.dec_embed_dim
|
527 |
+
return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf,
|
528 |
+
feature_dim=feature_dim,
|
529 |
+
last_dim=last_dim,
|
530 |
+
hooks_idx=[0, l2*2//4, l2*3//4, l2],
|
531 |
+
dim_tokens=[ed, dd, dd, dd],
|
532 |
+
postprocess=postprocess,
|
533 |
+
depth_mode=net.depth_mode,
|
534 |
+
conf_mode=net.conf_mode,
|
535 |
+
head_type='regression',
|
536 |
+
patch_size=net.patch_size)
|
537 |
+
|
538 |
+
def create_dpt_head_depth(net, has_conf=False):
|
539 |
+
"""
|
540 |
+
return PixelwiseTaskWithDPT for given net params
|
541 |
+
"""
|
542 |
+
assert net.dec_depth > 9
|
543 |
+
l2 = net.dec_depth - 1
|
544 |
+
feature_dim = 256
|
545 |
+
last_dim = feature_dim//2
|
546 |
+
out_nchan = 1
|
547 |
+
ed = net.enc_embed_dim
|
548 |
+
dd = net.dec_embed_dim
|
549 |
+
return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf,
|
550 |
+
feature_dim=feature_dim,
|
551 |
+
last_dim=last_dim,
|
552 |
+
hooks_idx=[0, l2*2//4, l2*3//4, l2],
|
553 |
+
dim_tokens=[ed, dd, dd, dd],
|
554 |
+
postprocess=postprocess,
|
555 |
+
depth_mode=net.depth_mode,
|
556 |
+
conf_mode=net.conf_mode,
|
557 |
+
head_type='regression',
|
558 |
+
patch_size=net.patch_size)
|
559 |
+
|
560 |
+
|
561 |
+
def create_dpt_head_mask(net, has_conf=False):
|
562 |
+
"""
|
563 |
+
return PixelwiseTaskWithDPT for given net params
|
564 |
+
"""
|
565 |
+
assert net.dec_depth > 9
|
566 |
+
l2 = net.dec_depth - 1
|
567 |
+
feature_dim = 256
|
568 |
+
last_dim = feature_dim//2
|
569 |
+
out_nchan = 3
|
570 |
+
ed = net.enc_embed_dim
|
571 |
+
dd = net.dec_embed_dim
|
572 |
+
return PixelwiseTaskWithDPT(num_channels=1 + has_conf,
|
573 |
+
feature_dim=feature_dim,
|
574 |
+
last_dim=last_dim,
|
575 |
+
hooks_idx=[0, l2*2//4, l2*3//4, l2],
|
576 |
+
dim_tokens=[ed, dd, dd, dd],
|
577 |
+
postprocess=postprocess,
|
578 |
+
depth_mode=net.depth_mode,
|
579 |
+
conf_mode=net.conf_mode,
|
580 |
+
classifier_mode=net.classifier_mode,
|
581 |
+
head_type='regression',
|
582 |
+
patch_size=net.patch_size)
|
models/heads/linear_head.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from .postprocess import postprocess
|
5 |
+
|
6 |
+
class LinearPts3d (nn.Module):
|
7 |
+
"""
|
8 |
+
Linear head for dust3r
|
9 |
+
Each token outputs: - 16x16 3D points (+ confidence)
|
10 |
+
"""
|
11 |
+
|
12 |
+
def __init__(self, net, has_conf=False,mode='pts3d'):
|
13 |
+
super().__init__()
|
14 |
+
self.patch_size = net.patch_size
|
15 |
+
self.depth_mode = net.depth_mode
|
16 |
+
self.conf_mode = net.conf_mode
|
17 |
+
self.has_conf = has_conf
|
18 |
+
self.mode = mode
|
19 |
+
self.classifier_mode = None
|
20 |
+
if self.mode == 'pts3d':
|
21 |
+
self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2)
|
22 |
+
elif self.mode == 'depth':
|
23 |
+
self.proj = nn.Linear(net.dec_embed_dim, (1 + has_conf)*self.patch_size**2)
|
24 |
+
elif self.mode == 'classifier':
|
25 |
+
self.proj = nn.Linear(net.dec_embed_dim, (1 + has_conf)*self.patch_size**2)
|
26 |
+
self.classifier_mode = net.classifier_mode
|
27 |
+
|
28 |
+
def setup(self, croconet):
|
29 |
+
pass
|
30 |
+
|
31 |
+
def forward(self, decout, img_shape):
|
32 |
+
H, W = img_shape
|
33 |
+
tokens = decout[-1]
|
34 |
+
B, S, D = tokens.shape
|
35 |
+
|
36 |
+
# extract 3D points
|
37 |
+
feat = self.proj(tokens) # B,S,D
|
38 |
+
feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size)
|
39 |
+
feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W
|
40 |
+
|
41 |
+
# permute + norm depth
|
42 |
+
return postprocess(feat, self.depth_mode, self.conf_mode,self.classifier_mode)
|
models/heads/postprocess.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
def postprocess(out, depth_mode, conf_mode,classifier_mode=None):
|
4 |
+
"""
|
5 |
+
extract 3D points/confidence from prediction head output
|
6 |
+
"""
|
7 |
+
fmap = out.permute(0, 2, 3, 1) # B,H,W,3
|
8 |
+
if classifier_mode is None:
|
9 |
+
if fmap.shape[-1] == 4:
|
10 |
+
res = dict(pointmaps=reg_dense_pts3d(fmap[:, :, :, :-1], mode=depth_mode))
|
11 |
+
else:
|
12 |
+
res = dict(depths=reg_dense_depth(fmap[:, :, :, 0], mode=depth_mode))
|
13 |
+
if conf_mode is not None:
|
14 |
+
res['conf_pointmaps'] = reg_dense_conf(fmap[:, :, :, -1], mode=conf_mode)
|
15 |
+
else:
|
16 |
+
res = dict(classifier=reg_dense_classifier(fmap[:, :, :, 0], mode=classifier_mode))
|
17 |
+
if conf_mode is not None:
|
18 |
+
res['conf_classifier'] = reg_dense_conf(fmap[:, :, :, 1], mode=conf_mode)
|
19 |
+
|
20 |
+
return res
|
21 |
+
|
22 |
+
def reg_dense_classifier(x, mode):
|
23 |
+
"""
|
24 |
+
extract classifier from prediction head output
|
25 |
+
"""
|
26 |
+
mode, vmin, vmax = mode
|
27 |
+
#return torch.sigmoid(x)
|
28 |
+
return x
|
29 |
+
|
30 |
+
def reg_dense_depth(x, mode):
|
31 |
+
"""
|
32 |
+
extract depth from prediction head output
|
33 |
+
"""
|
34 |
+
mode, vmin, vmax = mode
|
35 |
+
no_bounds = (vmin == -float('inf')) and (vmax == float('inf'))
|
36 |
+
assert no_bounds
|
37 |
+
if mode == 'linear':
|
38 |
+
return x
|
39 |
+
elif mode == 'square':
|
40 |
+
return x.square().clip(min=vmin, max=vmax)
|
41 |
+
elif mode == 'exp':
|
42 |
+
return torch.exp(x).clip(min=vmin, max=vmax)
|
43 |
+
else:
|
44 |
+
raise ValueError(f'bad {mode=}')
|
45 |
+
|
46 |
+
def reg_dense_pts3d(xyz, mode):
|
47 |
+
"""
|
48 |
+
extract 3D points from prediction head output
|
49 |
+
"""
|
50 |
+
mode, vmin, vmax = mode
|
51 |
+
|
52 |
+
no_bounds = (vmin == -float('inf')) and (vmax == float('inf'))
|
53 |
+
assert no_bounds
|
54 |
+
|
55 |
+
if mode == 'linear':
|
56 |
+
if no_bounds:
|
57 |
+
return xyz # [-inf, +inf]
|
58 |
+
return xyz.clip(min=vmin, max=vmax)
|
59 |
+
|
60 |
+
# distance to origin
|
61 |
+
d = xyz.norm(dim=-1, keepdim=True)
|
62 |
+
xyz = xyz / d.clip(min=1e-8)
|
63 |
+
if mode == 'square':
|
64 |
+
return xyz * d.square()
|
65 |
+
|
66 |
+
if mode == 'exp':
|
67 |
+
return xyz * torch.expm1(d)
|
68 |
+
raise ValueError(f'bad {mode=}')
|
69 |
+
|
70 |
+
def reg_dense_conf(x, mode):
|
71 |
+
"""
|
72 |
+
extract confidence from prediction head output
|
73 |
+
"""
|
74 |
+
mode, vmin, vmax = mode
|
75 |
+
if mode == 'exp':
|
76 |
+
return vmin + x.exp().clip(max=vmax-vmin)
|
77 |
+
if mode == 'sigmoid':
|
78 |
+
return (vmax - vmin) * torch.sigmoid(x) + vmin
|
79 |
+
raise ValueError(f'bad {mode=}')
|
80 |
+
|
models/losses.py
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
bb = breakpoint
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import copy
|
5 |
+
from utils.geometry import normalize_pointcloud
|
6 |
+
|
7 |
+
class Criterion (nn.Module):
|
8 |
+
def __init__(self, criterion=None):
|
9 |
+
super().__init__()
|
10 |
+
self.criterion = copy.deepcopy(criterion)
|
11 |
+
|
12 |
+
def get_name(self):
|
13 |
+
return f'{type(self).__name__}({self.criterion})'
|
14 |
+
|
15 |
+
class CrocoLoss (nn.Module):
|
16 |
+
def __init__(self,mode='vanilla',eps=1e-4):
|
17 |
+
super().__init__()
|
18 |
+
self.mode = mode
|
19 |
+
def get_name(self):
|
20 |
+
return f'CrocoLoss({self.mode})'
|
21 |
+
|
22 |
+
def forward(self, pred, gt, **kw):
|
23 |
+
pred_pts = pred['pointmaps']
|
24 |
+
conf = pred['conf']
|
25 |
+
|
26 |
+
if self.mode == 'vanilla':
|
27 |
+
loss = torch.abs(gt-pred_pts)/(torch.exp(conf)) + conf
|
28 |
+
elif self.mode == 'bounded_1':
|
29 |
+
a=0.25
|
30 |
+
b=4.
|
31 |
+
conf = (b-a)*torch.sigmoid(conf) + a
|
32 |
+
loss = torch.abs(gt-pred_pts)/(conf) + torch.log(conf)
|
33 |
+
elif self.mode == 'bounded_2':
|
34 |
+
a = 3.0
|
35 |
+
b = 3.0
|
36 |
+
conf = 2*a * (torch.sigmoid(conf/b)-0.5)
|
37 |
+
loss = torch.abs(gt-pred_pts)/torch.exp(conf) + conf
|
38 |
+
return loss.mean()
|
39 |
+
|
40 |
+
class SMDLoss (nn.Module):
|
41 |
+
def __init__(self,raw_loss,mode='linear'):
|
42 |
+
super().__init__()
|
43 |
+
self.mode = mode
|
44 |
+
self.raw_loss = raw_loss
|
45 |
+
def get_name(self):
|
46 |
+
return f'SMDLoss({self.raw_loss},{self.mode})'
|
47 |
+
|
48 |
+
def forward(self, pred, gt,eps, **kw):
|
49 |
+
p_gt = compute_probs(pred,gt,eps=eps)
|
50 |
+
# filtering out nan values
|
51 |
+
loss = self.raw_loss(p_gt)
|
52 |
+
loss_mask = ~torch.isnan(p_gt) & (loss != torch.inf).bool()
|
53 |
+
loss = loss[loss_mask]
|
54 |
+
return loss.mean()
|
55 |
+
|
56 |
+
# https://github.com/naver/dust3r/blob/c9e9336a6ba7c1f1873f9295852cea6dffaf770d/dust3r/losses.py#L197
|
57 |
+
class ConfLoss (nn.Module):
|
58 |
+
""" Weighted regression by learned confidence.
|
59 |
+
Assuming the input pixel_loss is a pixel-level regression loss.
|
60 |
+
|
61 |
+
Principle:
|
62 |
+
high-confidence means high conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
|
63 |
+
low confidence means low conf = 10 ==> conf_loss = x * 10 - alpha*log(10)
|
64 |
+
|
65 |
+
alpha: hyperparameter
|
66 |
+
"""
|
67 |
+
|
68 |
+
def __init__(self, raw_loss, alpha=0.2,skip_conf=False):
|
69 |
+
super().__init__()
|
70 |
+
assert alpha > 0
|
71 |
+
self.alpha = alpha
|
72 |
+
self.raw_loss = raw_loss
|
73 |
+
self.skip_conf = skip_conf
|
74 |
+
|
75 |
+
def get_name(self):
|
76 |
+
return f'ConfLoss({self.raw_loss})'
|
77 |
+
|
78 |
+
def get_conf_log(self, x):
|
79 |
+
return x, torch.log(x)
|
80 |
+
|
81 |
+
def forward(self, pred, gt,conf, **kw):
|
82 |
+
# compute per-pixel loss
|
83 |
+
loss = self.raw_loss(gt, pred, **kw)
|
84 |
+
# weight by confidence
|
85 |
+
if not self.skip_conf:
|
86 |
+
conf, log_conf = self.get_conf_log(conf)
|
87 |
+
conf_loss = loss * conf - self.alpha * log_conf
|
88 |
+
## average + nan protection (in case of no valid pixels at all)
|
89 |
+
conf_loss = conf_loss.mean() if conf_loss.numel() > 0 else 0
|
90 |
+
return conf_loss
|
91 |
+
else:
|
92 |
+
return loss.mean()
|
93 |
+
|
94 |
+
|
95 |
+
class BCELoss(nn.Module):
|
96 |
+
def __init__(self):
|
97 |
+
super().__init__()
|
98 |
+
|
99 |
+
def get_name(self):
|
100 |
+
return f'BCELoss()'
|
101 |
+
|
102 |
+
def forward(self, gt, pred):
|
103 |
+
# return torch.nn.functional.binary_cross_entropy(pred, gt)
|
104 |
+
return torch.nn.functional.binary_cross_entropy_with_logits(pred, gt)
|
105 |
+
|
106 |
+
class ClassifierLoss(nn.Module):
|
107 |
+
def __init__(self,criterion):
|
108 |
+
super().__init__()
|
109 |
+
self.criterion = criterion
|
110 |
+
|
111 |
+
def get_name(self):
|
112 |
+
return f'ClassifierLoss({self.criterion})'
|
113 |
+
|
114 |
+
def forward(self, pred, gt):
|
115 |
+
return self.criterion(pred, gt)
|
116 |
+
|
117 |
+
class BaseCriterion(nn.Module):
|
118 |
+
def __init__(self, reduction='none'):
|
119 |
+
super().__init__()
|
120 |
+
self.reduction = reduction
|
121 |
+
|
122 |
+
class NLLLoss (BaseCriterion):
|
123 |
+
""" Negative log likelihood loss """
|
124 |
+
def forward(self, pred):
|
125 |
+
# assuming the pred is already a log (for stability sake)
|
126 |
+
return -pred
|
127 |
+
#return -torch.log(pred)
|
128 |
+
|
129 |
+
class LLoss (BaseCriterion):
|
130 |
+
""" L-norm loss
|
131 |
+
"""
|
132 |
+
def forward(self, a, b):
|
133 |
+
assert a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3, f'Bad shape = {a.shape}'
|
134 |
+
dist = self.distance(a, b)
|
135 |
+
assert dist.ndim == a.ndim - 1 # one dimension less
|
136 |
+
if self.reduction == 'none':
|
137 |
+
return dist
|
138 |
+
if self.reduction == 'sum':
|
139 |
+
return dist.sum()
|
140 |
+
if self.reduction == 'mean':
|
141 |
+
return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
|
142 |
+
raise ValueError(f'bad {self.reduction=} mode')
|
143 |
+
|
144 |
+
def distance(self, a, b):
|
145 |
+
raise NotImplementedError()
|
146 |
+
|
147 |
+
class L21Loss (LLoss):
|
148 |
+
""" Euclidean distance between 3d points """
|
149 |
+
|
150 |
+
def distance(self, a, b):
|
151 |
+
return torch.norm(a - b, dim=-1)
|
152 |
+
|
153 |
+
L21 = L21Loss()
|
154 |
+
|
155 |
+
def apply_log_to_norm(xyz):
|
156 |
+
d = xyz.norm(dim=-1, keepdim=True)
|
157 |
+
xyz = xyz / d.clip(min=1e-8)
|
158 |
+
xyz = xyz * torch.log1p(d)
|
159 |
+
return xyz
|
160 |
+
|
161 |
+
class DepthCompletion (Criterion):
|
162 |
+
def __init__(self, criterion, classifier_criterion=None,norm_mode='?None', loss_in_log=False,device='cuda',lambda_classifier=1.0):
|
163 |
+
super().__init__(criterion)
|
164 |
+
self.criterion.reduction = 'none'
|
165 |
+
self.loss_in_log = loss_in_log
|
166 |
+
self.device = device
|
167 |
+
self.lambda_classifier = lambda_classifier
|
168 |
+
self.classifier_criterion = classifier_criterion
|
169 |
+
|
170 |
+
if norm_mode.startswith('?'):
|
171 |
+
# do no norm pts from metric scale datasets
|
172 |
+
self.norm_all = False
|
173 |
+
self.norm_mode = norm_mode[1:]
|
174 |
+
else:
|
175 |
+
self.norm_all = True
|
176 |
+
self.norm_mode = norm_mode
|
177 |
+
|
178 |
+
def forward(self, pred_dict, gt_dict,**kw):
|
179 |
+
gt_depths = gt_dict['depths']
|
180 |
+
pred_depths = pred_dict['depths']
|
181 |
+
gt_masks = gt_dict['valid_masks']
|
182 |
+
if gt_masks.sum() == 0:
|
183 |
+
return None
|
184 |
+
else:
|
185 |
+
gt_depths_masked = gt_depths[gt_masks].view(-1,1)
|
186 |
+
pred_depths_masked = pred_depths[gt_masks].view(-1,1)
|
187 |
+
# this is a loss on the points on the objects
|
188 |
+
loss_dict = {'loss_points':self.criterion(pred_depths_masked, gt_depths_masked,pred_dict['conf_pointmaps'][gt_masks])}
|
189 |
+
# loss on predicting a mask for the points on the objects
|
190 |
+
if 'classifier' in pred_dict and self.classifier_criterion is not None:
|
191 |
+
loss_dict['loss_classifier'] = self.classifier_criterion(pred_dict['classifier'], gt_dict['valid_masks'].float(),pred_dict['conf_classifier'])
|
192 |
+
loss_dict['loss'] = loss_dict['loss_points'] + self.lambda_classifier * loss_dict['loss_classifier']
|
193 |
+
else:
|
194 |
+
loss_dict['loss'] = loss_dict['loss_points']
|
195 |
+
|
196 |
+
return loss_dict
|
197 |
+
|
198 |
+
|
199 |
+
class RayCompletion (Criterion):
|
200 |
+
def __init__(self, criterion, classifier_criterion=None,norm_mode='?None', loss_in_log=False,device='cuda',lambda_classifier=1.0):
|
201 |
+
super().__init__(criterion)
|
202 |
+
self.criterion.reduction = 'none'
|
203 |
+
self.loss_in_log = loss_in_log
|
204 |
+
self.device = device
|
205 |
+
self.lambda_classifier = lambda_classifier
|
206 |
+
self.classifier_criterion = classifier_criterion
|
207 |
+
|
208 |
+
if norm_mode.startswith('?'):
|
209 |
+
# do no norm pts from metric scale datasets
|
210 |
+
self.norm_all = False
|
211 |
+
self.norm_mode = norm_mode[1:]
|
212 |
+
else:
|
213 |
+
self.norm_all = True
|
214 |
+
self.norm_mode = norm_mode
|
215 |
+
|
216 |
+
def get_all_pts3d(self, gt_dict, pred_dict):
|
217 |
+
gt_pts1 = gt_dict['pointmaps']
|
218 |
+
#gt_pts_context = gt_dict['pointmaps_context'][:,0] # we use the first camera given as input for normalization, in our current case that's the only cam
|
219 |
+
if 'pointmaps' in pred_dict:
|
220 |
+
pr_pts1 = pred_dict['pointmaps']
|
221 |
+
else:
|
222 |
+
pr_pts1 = None
|
223 |
+
mask = gt_dict['valid_masks'].clone()
|
224 |
+
# normalize 3d points
|
225 |
+
norm_factor = None
|
226 |
+
|
227 |
+
return gt_pts1, pr_pts1, mask, norm_factor
|
228 |
+
|
229 |
+
def forward(self, pred_dict, gt_dict, eps=None,**kw):
|
230 |
+
gt_pts1, pred_pts1, mask, norm_factor = \
|
231 |
+
self.get_all_pts3d(gt_dict, pred_dict, **kw)
|
232 |
+
if mask.sum() == 0:
|
233 |
+
return None
|
234 |
+
else:
|
235 |
+
mask_repeated = mask.unsqueeze(-1).repeat(1,1,1,3)
|
236 |
+
if norm_factor is not None:
|
237 |
+
pred_pts1 = pred_pts1 / norm_factor
|
238 |
+
gt_pts1 = gt_pts1 / norm_factor
|
239 |
+
|
240 |
+
pred_pts1 = pred_pts1[mask_repeated].reshape(-1,3)
|
241 |
+
gt_pts1 = gt_pts1[mask_repeated].reshape(-1,3)
|
242 |
+
|
243 |
+
if self.loss_in_log and self.loss_in_log != 'before':
|
244 |
+
# this only make sense when depth_mode == 'exp'
|
245 |
+
pred_pts1 = apply_log_to_norm(pred_pts1)
|
246 |
+
gt_pts1 = apply_log_to_norm(gt_pts1)
|
247 |
+
|
248 |
+
# this is a loss on the points on the objects
|
249 |
+
loss_dict = {'loss_points':self.criterion(pred_pts1, gt_pts1,pred_dict['conf_pointmaps'][mask])}
|
250 |
+
# loss on predicting a mask for the points on the objects
|
251 |
+
if 'classifier' in pred_dict and self.classifier_criterion is not None:
|
252 |
+
loss_dict['loss_classifier'] = self.classifier_criterion(pred_dict['classifier'], gt_dict['valid_masks'].float(),pred_dict['conf_classifier'])
|
253 |
+
loss_dict['loss'] = loss_dict['loss_points'] + self.lambda_classifier * loss_dict['loss_classifier']
|
254 |
+
else:
|
255 |
+
loss_dict['loss'] = loss_dict['loss_points']
|
256 |
+
|
257 |
+
return loss_dict
|
models/pos_embed.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
3 |
+
# --------------------------------------------------------
|
4 |
+
# Position embedding utils
|
5 |
+
# --------------------------------------------------------
|
6 |
+
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
# --------------------------------------------------------
|
13 |
+
# 2D sine-cosine position embedding
|
14 |
+
# References:
|
15 |
+
# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
16 |
+
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
|
17 |
+
# MoCo v3: https://github.com/facebookresearch/moco-v3
|
18 |
+
# --------------------------------------------------------
|
19 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
|
20 |
+
"""
|
21 |
+
grid_size: int of the grid height and width
|
22 |
+
return:
|
23 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
24 |
+
"""
|
25 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
26 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
27 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
28 |
+
grid = np.stack(grid, axis=0)
|
29 |
+
|
30 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
31 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
32 |
+
if n_cls_token>0:
|
33 |
+
pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0)
|
34 |
+
return pos_embed
|
35 |
+
|
36 |
+
|
37 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
38 |
+
assert embed_dim % 2 == 0
|
39 |
+
|
40 |
+
# use half of dimensions to encode grid_h
|
41 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
42 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
43 |
+
|
44 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
45 |
+
return emb
|
46 |
+
|
47 |
+
|
48 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
49 |
+
"""
|
50 |
+
embed_dim: output dimension for each position
|
51 |
+
pos: a list of positions to be encoded: size (M,)
|
52 |
+
out: (M, D)
|
53 |
+
"""
|
54 |
+
assert embed_dim % 2 == 0
|
55 |
+
omega = np.arange(embed_dim // 2, dtype=float)
|
56 |
+
omega /= embed_dim / 2.
|
57 |
+
omega = 1. / 10000**omega # (D/2,)
|
58 |
+
|
59 |
+
pos = pos.reshape(-1) # (M,)
|
60 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
61 |
+
|
62 |
+
emb_sin = np.sin(out) # (M, D/2)
|
63 |
+
emb_cos = np.cos(out) # (M, D/2)
|
64 |
+
|
65 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
66 |
+
return emb
|
67 |
+
|
68 |
+
|
69 |
+
# --------------------------------------------------------
|
70 |
+
# Interpolate position embeddings for high-resolution
|
71 |
+
# References:
|
72 |
+
# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
73 |
+
# DeiT: https://github.com/facebookresearch/deit
|
74 |
+
# --------------------------------------------------------
|
75 |
+
def interpolate_pos_embed(model, checkpoint_model):
|
76 |
+
if 'pos_embed' in checkpoint_model:
|
77 |
+
pos_embed_checkpoint = checkpoint_model['pos_embed']
|
78 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
79 |
+
num_patches = model.patch_embed.num_patches
|
80 |
+
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
|
81 |
+
# height (== width) for the checkpoint position embedding
|
82 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
83 |
+
# height (== width) for the new position embedding
|
84 |
+
new_size = int(num_patches ** 0.5)
|
85 |
+
# class_token and dist_token are kept unchanged
|
86 |
+
if orig_size != new_size:
|
87 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
88 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
89 |
+
# only the position tokens are interpolated
|
90 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
91 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
92 |
+
pos_tokens = torch.nn.functional.interpolate(
|
93 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
94 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
95 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
96 |
+
checkpoint_model['pos_embed'] = new_pos_embed
|
97 |
+
|
98 |
+
|
99 |
+
#----------------------------------------------------------
|
100 |
+
# RoPE2D: RoPE implementation in 2D
|
101 |
+
#----------------------------------------------------------
|
102 |
+
|
103 |
+
try:
|
104 |
+
from extensions.curope import cuRoPE2D
|
105 |
+
RoPE2D = cuRoPE2D
|
106 |
+
except ImportError:
|
107 |
+
print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead')
|
108 |
+
|
109 |
+
class RoPE2D(torch.nn.Module):
|
110 |
+
|
111 |
+
def __init__(self, freq=100.0, F0=1.0):
|
112 |
+
super().__init__()
|
113 |
+
self.base = freq
|
114 |
+
self.F0 = F0
|
115 |
+
self.cache = {}
|
116 |
+
|
117 |
+
def get_cos_sin(self, D, seq_len, device, dtype):
|
118 |
+
if (D,seq_len,device,dtype) not in self.cache:
|
119 |
+
inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
|
120 |
+
t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
|
121 |
+
freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
|
122 |
+
freqs = torch.cat((freqs, freqs), dim=-1)
|
123 |
+
cos = freqs.cos() # (Seq, Dim)
|
124 |
+
sin = freqs.sin()
|
125 |
+
self.cache[D,seq_len,device,dtype] = (cos,sin)
|
126 |
+
return self.cache[D,seq_len,device,dtype]
|
127 |
+
|
128 |
+
@staticmethod
|
129 |
+
def rotate_half(x):
|
130 |
+
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
|
131 |
+
return torch.cat((-x2, x1), dim=-1)
|
132 |
+
|
133 |
+
def apply_rope1d(self, tokens, pos1d, cos, sin):
|
134 |
+
assert pos1d.ndim==2
|
135 |
+
cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
|
136 |
+
sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
|
137 |
+
return (tokens * cos) + (self.rotate_half(tokens) * sin)
|
138 |
+
|
139 |
+
def forward(self, tokens, positions):
|
140 |
+
"""
|
141 |
+
input:
|
142 |
+
* tokens: batch_size x nheads x ntokens x dim
|
143 |
+
* positions: batch_size x ntokens x 2 (y and x position of each token)
|
144 |
+
output:
|
145 |
+
* tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim)
|
146 |
+
"""
|
147 |
+
assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
|
148 |
+
D = tokens.size(3) // 2
|
149 |
+
assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2
|
150 |
+
cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype)
|
151 |
+
# split features into two along the feature dimension, and apply rope1d on each half
|
152 |
+
y, x = tokens.chunk(2, dim=-1)
|
153 |
+
y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
|
154 |
+
x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
|
155 |
+
tokens = torch.cat((y, x), dim=-1)
|
156 |
+
return tokens
|
models/rayquery.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
bb = breakpoint
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from models.blocks import DecoderBlock, Block, PatchEmbed, PositionGetter
|
5 |
+
from models.pos_embed import get_2d_sincos_pos_embed, RoPE2D
|
6 |
+
from models.losses import *
|
7 |
+
from utils.geometry import center_pointmaps, compute_rays
|
8 |
+
from models.heads import head_factory
|
9 |
+
|
10 |
+
def init_weights(m):
|
11 |
+
if isinstance(m, nn.Linear):
|
12 |
+
# we use xavier_uniform following official JAX ViT:
|
13 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
14 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
15 |
+
nn.init.constant_(m.bias, 0)
|
16 |
+
elif isinstance(m, nn.LayerNorm):
|
17 |
+
if m.bias is not None:
|
18 |
+
nn.init.constant_(m.bias, 0)
|
19 |
+
if m.weight is not None:
|
20 |
+
nn.init.constant_(m.weight, 1.0)
|
21 |
+
elif isinstance(m, nn.Parameter):
|
22 |
+
nn.init.normal_(m, std=0.02)
|
23 |
+
|
24 |
+
class RayEncoder(nn.Module):
|
25 |
+
def __init__(self,
|
26 |
+
dim=256,
|
27 |
+
patch_size=8,
|
28 |
+
img_size=(128,128),
|
29 |
+
depth=3,
|
30 |
+
num_heads=4,
|
31 |
+
pos_embed='RoPE100',
|
32 |
+
):
|
33 |
+
super().__init__()
|
34 |
+
self.img_size = img_size
|
35 |
+
self.patch_embed = PatchEmbed(img_size=self.img_size, patch_size=patch_size, in_chans=2, embed_dim=dim)
|
36 |
+
self.dim = dim
|
37 |
+
if pos_embed.startswith('RoPE'):
|
38 |
+
freq = float(pos_embed[len('RoPE'):])
|
39 |
+
self.rope = RoPE2D(freq=freq)
|
40 |
+
else:
|
41 |
+
self.rope = None
|
42 |
+
self.blocks = nn.ModuleList([Block(dim=dim, num_heads=num_heads,rope=self.rope) for _ in range(depth)])
|
43 |
+
self.initialize_weights()
|
44 |
+
|
45 |
+
def initialize_weights(self):
|
46 |
+
# patch embed
|
47 |
+
self.patch_embed._init_weights()
|
48 |
+
|
49 |
+
# linears and layer norms
|
50 |
+
self.apply(init_weights)
|
51 |
+
|
52 |
+
def forward(self, rays):
|
53 |
+
rays = rays.permute(0,3,1,2)
|
54 |
+
rays, pos = self.patch_embed(rays)
|
55 |
+
for blk in self.blocks:
|
56 |
+
rays = blk(rays, pos)
|
57 |
+
return rays, pos
|
58 |
+
|
59 |
+
class PointmapEncoder(nn.Module):
|
60 |
+
def __init__(self,
|
61 |
+
dim=256,
|
62 |
+
patch_size=8,
|
63 |
+
img_size=(128,128),
|
64 |
+
depth=3,
|
65 |
+
num_heads=4,
|
66 |
+
pos_embed='RoPE100',
|
67 |
+
):
|
68 |
+
super().__init__()
|
69 |
+
self.img_size = img_size
|
70 |
+
self.patch_embed = PatchEmbed(img_size=self.img_size, patch_size=patch_size, in_chans=3, embed_dim=dim)
|
71 |
+
self.dim = dim
|
72 |
+
self.patch_size = patch_size
|
73 |
+
|
74 |
+
if pos_embed.startswith('RoPE'):
|
75 |
+
freq = float(pos_embed[len('RoPE'):])
|
76 |
+
self.rope = RoPE2D(freq=freq)
|
77 |
+
else:
|
78 |
+
self.rope = None
|
79 |
+
self.blocks = nn.ModuleList([Block(dim=dim, num_heads=num_heads,rope=self.rope) for _ in range(depth)])
|
80 |
+
self.masked_token = nn.Parameter(torch.randn(1,1,3))
|
81 |
+
self.initialize_weights()
|
82 |
+
|
83 |
+
def initialize_weights(self):
|
84 |
+
# patch embed
|
85 |
+
self.patch_embed._init_weights()
|
86 |
+
|
87 |
+
# linears and layer norms
|
88 |
+
self.apply(init_weights)
|
89 |
+
|
90 |
+
def forward(self, pointmaps,masks=None):
|
91 |
+
# replace masked points (not on object) with a learned token
|
92 |
+
pointmaps[~masks] = self.masked_token.to(pointmaps.dtype).to(pointmaps.device)
|
93 |
+
pointmaps = pointmaps.permute(0,3,1,2)
|
94 |
+
pointmaps, pos = self.patch_embed(pointmaps)
|
95 |
+
|
96 |
+
for blk in self.blocks:
|
97 |
+
pointmaps = blk(pointmaps, pos)
|
98 |
+
return pointmaps, pos
|
99 |
+
|
100 |
+
class RayQuery(nn.Module):
|
101 |
+
def __init__(self,
|
102 |
+
ray_enc=RayEncoder(),
|
103 |
+
pointmap_enc=PointmapEncoder(),
|
104 |
+
dec_pos_embed='RoPE100',
|
105 |
+
decoder_dim=256,
|
106 |
+
decoder_depth=3,
|
107 |
+
decoder_num_heads=4,
|
108 |
+
imshape=(128,128),
|
109 |
+
pts_head_type='dpt',
|
110 |
+
classifier_head_type='dpt_mask',
|
111 |
+
criterion=ConfLoss(L21),
|
112 |
+
return_all_blocks=True,
|
113 |
+
depth_mode=('exp',-float('inf'),float('inf')),
|
114 |
+
conf_mode=('exp',1,float('inf')),
|
115 |
+
classifier_mode=('raw',0,1),
|
116 |
+
dino_layers=[23],
|
117 |
+
):
|
118 |
+
super().__init__()
|
119 |
+
self.ray_enc = ray_enc
|
120 |
+
self.pointmap_enc = pointmap_enc
|
121 |
+
self.dec_depth = decoder_depth
|
122 |
+
self.dec_embed_dim = decoder_dim
|
123 |
+
self.enc_embed_dim = ray_enc.dim
|
124 |
+
self.patch_size = pointmap_enc.patch_size
|
125 |
+
self.depth_mode = depth_mode
|
126 |
+
self.conf_mode = conf_mode
|
127 |
+
self.classifier_mode = classifier_mode
|
128 |
+
self.skip_dino = len(dino_layers) == 0
|
129 |
+
self.pts_head_type = pts_head_type
|
130 |
+
self.classifier_head_type = classifier_head_type
|
131 |
+
|
132 |
+
if dec_pos_embed.startswith('RoPE'):
|
133 |
+
self.dec_pos_embed = RoPE2D(freq=100.0)
|
134 |
+
else:
|
135 |
+
raise NotImplementedError(f'{dec_pos_embed} not implemented')
|
136 |
+
self.decoder_blocks = nn.ModuleList([DecoderBlock(dim=decoder_dim, num_heads=decoder_num_heads,
|
137 |
+
rope=self.dec_pos_embed) for _ in range(decoder_depth)])
|
138 |
+
self.pts_head = head_factory(pts_head_type, 'pts3d', self, has_conf=True)
|
139 |
+
|
140 |
+
self.classifier_head = head_factory(classifier_head_type, 'pts3d', self, has_conf=True)
|
141 |
+
self.imshape = imshape
|
142 |
+
self.criterion = criterion
|
143 |
+
self.return_all_blocks = return_all_blocks
|
144 |
+
|
145 |
+
# dino projection
|
146 |
+
self.dino_layers = dino_layers
|
147 |
+
self.dino_proj = nn.Linear(1024 * len(dino_layers), decoder_dim)
|
148 |
+
self.dino_pos_getter = PositionGetter()
|
149 |
+
|
150 |
+
self.initialize_weights()
|
151 |
+
|
152 |
+
def initialize_weights(self):
|
153 |
+
self.apply(init_weights)
|
154 |
+
|
155 |
+
def forward_encoders(self, rays, pointmaps,masks=None):
|
156 |
+
# encode rays
|
157 |
+
rays, rays_pos = self.ray_enc(rays)
|
158 |
+
|
159 |
+
# encode pointmaps
|
160 |
+
B, H, W, C = pointmaps.shape
|
161 |
+
pointmaps = pointmaps.reshape(B,H,W,C) # each pointmap is encoded separately
|
162 |
+
pointmaps, pointmaps_pos = self.pointmap_enc(pointmaps,masks=masks)
|
163 |
+
new_shape = pointmaps.shape
|
164 |
+
pointmaps = pointmaps.reshape(new_shape[0],*new_shape[1:])
|
165 |
+
pointmaps_pos = pointmaps_pos[:B]
|
166 |
+
|
167 |
+
return rays, rays_pos, pointmaps, pointmaps_pos
|
168 |
+
|
169 |
+
def forward_decoder(self, rays, rays_pos, pointmaps, pointmaps_pos):
|
170 |
+
if self.return_all_blocks:
|
171 |
+
all_blocks = []
|
172 |
+
for blk in self.decoder_blocks:
|
173 |
+
rays, pointmaps = blk(rays, pointmaps, rays_pos, pointmaps_pos)
|
174 |
+
all_blocks.append(rays)
|
175 |
+
return all_blocks
|
176 |
+
else:
|
177 |
+
for blk in self.decoder_blocks:
|
178 |
+
rays, pointmaps = blk(rays, pointmaps, rays_pos, pointmaps_pos)
|
179 |
+
return rays
|
180 |
+
|
181 |
+
def get_dino_pos(self,dino_features):
|
182 |
+
# dino runs on 14x14 patches
|
183 |
+
# note: assuming we cropped or resized down!
|
184 |
+
dino_H = self.imshape[0]//14
|
185 |
+
dino_W = self.imshape[1]//14
|
186 |
+
dino_pos = self.dino_pos_getter(dino_features.shape[0],dino_H,dino_W,dino_features.device)
|
187 |
+
return dino_pos
|
188 |
+
|
189 |
+
def forward(self,batch,mode='loss'):
|
190 |
+
# prep for encoders
|
191 |
+
rays = compute_rays(batch) # we are querying the first camera
|
192 |
+
pointmaps_context = batch['input_cams']['pointmaps'] # we are using the other cameras as context
|
193 |
+
input_masks = batch['input_cams']['valid_masks']
|
194 |
+
|
195 |
+
# run the encoders
|
196 |
+
rays, rays_pos, pointmaps, pointmaps_pos = self.forward_encoders(rays, pointmaps_context,masks=input_masks)
|
197 |
+
## adding dino features
|
198 |
+
if not self.skip_dino:
|
199 |
+
dino_features = batch['input_cams']['dino_features']
|
200 |
+
dino_features = self.dino_proj(dino_features)
|
201 |
+
if len(dino_features.shape) == 4:
|
202 |
+
dino_features = dino_features.squeeze(1)
|
203 |
+
dino_pos = self.get_dino_pos(dino_features)
|
204 |
+
pointmaps = torch.cat([pointmaps,dino_features],dim=1)
|
205 |
+
pointmaps_pos = torch.cat([pointmaps_pos,dino_pos],dim=1)
|
206 |
+
else:
|
207 |
+
dino_features = None
|
208 |
+
dino_pos = None
|
209 |
+
# decoder
|
210 |
+
rays = self.forward_decoder(rays, rays_pos, pointmaps, pointmaps_pos)
|
211 |
+
pts_pred_dict = self.pts_head(rays, self.imshape)
|
212 |
+
classifier_pred_dict = self.classifier_head(rays, self.imshape)
|
213 |
+
|
214 |
+
pred_dict = {**pts_pred_dict,**classifier_pred_dict}
|
215 |
+
gt_dict = batch['new_cams']
|
216 |
+
loss_dict = self.criterion(pred_dict, gt_dict)
|
217 |
+
|
218 |
+
del rays, rays_pos, pointmaps, pointmaps_pos, dino_features, dino_pos, pointmaps_context, input_masks, pts_pred_dict, classifier_pred_dict
|
219 |
+
|
220 |
+
if mode == 'loss':
|
221 |
+
# delete all the variables that are not needed
|
222 |
+
del pred_dict, gt_dict
|
223 |
+
return loss_dict
|
224 |
+
elif mode == 'viz':
|
225 |
+
return pred_dict, gt_dict, loss_dict
|
226 |
+
else:
|
227 |
+
raise ValueError(f"Invalid mode: {mode}")
|
readme.md
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center", documentation will follow later.
|
2 |
+
|
3 |
+
# RaySt3R: Predicting Novel Depth Maps for Zero-Shot Object Completion
|
4 |
+
|
5 |
+
<a href="https://arxiv.org/abs/2506.05285"><img src='https://img.shields.io/badge/arXiv-Paper-red?logo=arxiv&logoColor=white' alt='arXiv'></a>
|
6 |
+
<a href='https://rayst3r.github.io'><img src='https://img.shields.io/badge/Project_Page-Website-green?logo=googlechrome&logoColor=white' alt='Project Page'></a>
|
7 |
+
|
8 |
+
</div>
|
9 |
+
|
10 |
+
<div align="center">
|
11 |
+
<img src="assets/overview.png" width="80%" alt="Method overview">
|
12 |
+
</div>
|
13 |
+
|
14 |
+
## 📚 Citation
|
15 |
+
```bibtex
|
16 |
+
@misc{rayst3r,
|
17 |
+
title={RaySt3R: Predicting Novel Depth Maps for Zero-Shot Object Completion},
|
18 |
+
author={Bardienus P. Duisterhof and Jan Oberst and Bowen Wen and Stan Birchfield and Deva Ramanan and Jeffrey Ichnowski},
|
19 |
+
year={2025},
|
20 |
+
eprint={2506.05285},
|
21 |
+
archivePrefix={arXiv},
|
22 |
+
primaryClass={cs.CV},
|
23 |
+
url={https://arxiv.org/abs/2506.05285},
|
24 |
+
}
|
25 |
+
```
|
26 |
+
## ✅ TO-DOs
|
27 |
+
|
28 |
+
- [x] Inference code
|
29 |
+
- [x] Local gradio demo
|
30 |
+
- [ ] Huggingface demo
|
31 |
+
- [ ] Docker
|
32 |
+
- [ ] Training code
|
33 |
+
- [ ] Eval code
|
34 |
+
- [ ] ViT-S, No-DINO and Pointmap models
|
35 |
+
- [ ] Dataset release
|
36 |
+
|
37 |
+
# ⚙️ Installation
|
38 |
+
|
39 |
+
```bash
|
40 |
+
mamba create -n rayst3r python=3.11 cmake=3.14.0
|
41 |
+
mamba activate rayst3r
|
42 |
+
mamba install pytorch torchvision pytorch-cuda=12.4 -c pytorch -c nvidia # change to your version of cuda
|
43 |
+
pip install -r requirements.txt
|
44 |
+
|
45 |
+
# compile the cuda kernels for RoPE
|
46 |
+
cd extensions/curope/
|
47 |
+
python setup.py build_ext --inplace
|
48 |
+
cd ../../
|
49 |
+
```
|
50 |
+
|
51 |
+
# 🚀 Usage
|
52 |
+
|
53 |
+
The expected input for RaySt3R is a folder with the following structure:
|
54 |
+
|
55 |
+
<pre><code>
|
56 |
+
📁 data_dir/
|
57 |
+
├── cam2world.pt # Camera-to-world transformation (PyTorch tensor), 4x4 - eye(4) if not provided
|
58 |
+
├── depth.png # Depth image, uint16 with max 10 meters
|
59 |
+
├── intrinsics.pt # Camera intrinsics (PyTorch tensor), 3x3
|
60 |
+
├── mask.png # Binary mask image
|
61 |
+
└── rgb.png # RGB image
|
62 |
+
</code></pre>
|
63 |
+
|
64 |
+
Note the depth image needs to be saved in uint16, normalized to a 0-10 meters range. We provide an example directory in `example_scene`.
|
65 |
+
Run RaySt3R with:
|
66 |
+
|
67 |
+
|
68 |
+
```bash
|
69 |
+
python3 eval_wrapper/eval.py example_scene/
|
70 |
+
```
|
71 |
+
This writes a colored point cloud back into the input directory.
|
72 |
+
|
73 |
+
Optional flags:
|
74 |
+
```bash
|
75 |
+
--visualize # Spins up a rerun client to visualize predictions and camera posees
|
76 |
+
--run_octmae # Novel views sampled with the OctMAE parameters (see paper)
|
77 |
+
--set_conf N # Sets confidence threshold to N
|
78 |
+
--n_pred_views # Number of predicted views along each axis in a grid, 5--> 22 views total
|
79 |
+
--filter_all_masks # Use all masks, point gets rejected if in background for a single mask
|
80 |
+
--tsdf # Fits TSDF to depth maps
|
81 |
+
```
|
82 |
+
|
83 |
+
# 🧪 Gradio app
|
84 |
+
|
85 |
+
We also provide a gradio app, which uses <a href="https://wangrc.site/MoGePage/">MoGe</a> and <a href="https://github.com/danielgatis/rembg">Rembg</a> to generate 3D from a single image.
|
86 |
+
|
87 |
+
Launch it with:
|
88 |
+
```bash
|
89 |
+
python app.py
|
90 |
+
```
|
91 |
+
|
92 |
+
# 🎛️ Parameter Guide
|
93 |
+
|
94 |
+
Certain applications may benefit from different hyper parameters, here we provide guidance on how to select them.
|
95 |
+
|
96 |
+
#### 🔁 View Sampling
|
97 |
+
|
98 |
+
We sample novel views evenly on a cylindrical equal-area projection of the sphere.
|
99 |
+
Customize sampling in <a href="eval_wrapper/sample_poses.py">sample_poses.py</a>. Use --n_pred_views to reduce the total number of views, making inference faster and reduce overlap and artifacts.
|
100 |
+
|
101 |
+
#### 🟢 Confidence Threshold
|
102 |
+
|
103 |
+
You can set the confidence threshold with the --set_conf threshold. As shown in the paper, a higher threshold generally improves accuracy, reduces edge bleeding but also affects completeness.
|
104 |
+
|
105 |
+
#### 🧼 RaySt3R Masks
|
106 |
+
|
107 |
+
On top of what was presented in the paper, we also provide the option to consider all predicted masks for each point. I.e., for any point, if any of the predicted masks classifies them as background the point gets removed.
|
108 |
+
In our limited testing this led to cleaner predictions, but it ocasinally carves out crucial parts of geometry.
|
109 |
+
|
110 |
+
# 🏋️ Training
|
111 |
+
|
112 |
+
The RaySt3R training command is provided in <a href="xps/train_rayst3r.py">train_rayst3r.py</a>, documentation will follow later.
|
requirements.txt
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
matplotlib
|
2 |
+
numpy
|
3 |
+
open3d
|
4 |
+
Pillow
|
5 |
+
pyrender
|
6 |
+
rerun
|
7 |
+
setuptools
|
8 |
+
tqdm
|
9 |
+
trimesh
|
10 |
+
huggingface-hub
|
11 |
+
wandb
|
12 |
+
einops
|
13 |
+
|
14 |
+
# for app.py
|
15 |
+
onnxruntime
|
16 |
+
gradio
|
17 |
+
rembg
|
18 |
+
git+https://github.com/microsoft/MoGe.git
|
utils/augmentations.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from abc import ABC, abstractmethod
|
5 |
+
from torchvision.transforms import GaussianBlur
|
6 |
+
from utils.batch_prep import compute_pointmaps
|
7 |
+
import imgaug as ia
|
8 |
+
import imgaug.augmenters as iaa
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
class ChangeBright(torch.nn.Module):
|
12 |
+
def __init__(self,prob=0.5,mag=[0.5,2.0]):
|
13 |
+
super().__init__()
|
14 |
+
self.mag = mag
|
15 |
+
self.prob = prob
|
16 |
+
|
17 |
+
def forward(self,rgb):
|
18 |
+
#if np.random.uniform()>=self.prob:
|
19 |
+
#return rgb
|
20 |
+
n = rgb.shape[0]
|
21 |
+
apply_aug = np.random.uniform(0,1,size=n) < self.prob
|
22 |
+
aug = iaa.MultiplyBrightness(np.random.uniform(self.mag[0],self.mag[1])) #NOTE iaa has bug about deterministic, we sample ourselves
|
23 |
+
rgb[apply_aug] = aug(images=rgb[apply_aug])
|
24 |
+
return rgb
|
25 |
+
|
26 |
+
class ChangeContrast(torch.nn.Module):
|
27 |
+
def __init__(self,prob=0.5,mag=[0.5,2.0]):
|
28 |
+
self.mag = mag
|
29 |
+
self.prob = prob
|
30 |
+
|
31 |
+
def __call__(self,rgb):
|
32 |
+
n = rgb.shape[0]
|
33 |
+
apply_aug = np.random.uniform(0,1,size=n) < self.prob
|
34 |
+
|
35 |
+
aug = iaa.GammaContrast(np.random.uniform(self.mag[0],self.mag[1]))
|
36 |
+
rgb[apply_aug] = aug(images=rgb[apply_aug])
|
37 |
+
return rgb
|
38 |
+
|
39 |
+
class SaltAndPepper:
|
40 |
+
def __init__(self, prob=0.3, ratio=0.1, per_channel=True):
|
41 |
+
self.prob = prob
|
42 |
+
self.ratio = ratio
|
43 |
+
self.per_channel = per_channel
|
44 |
+
|
45 |
+
def __call__(self, rgb):
|
46 |
+
n = rgb.shape[0]
|
47 |
+
apply_aug = np.random.uniform(0,1,size=n) < self.prob
|
48 |
+
aug = iaa.SaltAndPepper(self.ratio, per_channel=self.per_channel).to_deterministic()
|
49 |
+
rgb[apply_aug] = aug(images=rgb[apply_aug])
|
50 |
+
return rgb
|
51 |
+
|
52 |
+
class RGBGaussianNoise:
|
53 |
+
def __init__(self, max_noise=10, prob=0.5):
|
54 |
+
self.max_noise = max_noise
|
55 |
+
self.prob = prob
|
56 |
+
|
57 |
+
def __call__(self, rgb):
|
58 |
+
n = rgb.shape[0]
|
59 |
+
apply_aug = np.random.uniform(0,1,size=n) < self.prob
|
60 |
+
|
61 |
+
shape = rgb.shape
|
62 |
+
noise = np.random.normal(0, self.max_noise, size=shape).clip(-self.max_noise, self.max_noise)
|
63 |
+
rgb[apply_aug] = (rgb[apply_aug].astype(float) + noise[apply_aug]).clip(0,255).astype(np.uint8)
|
64 |
+
return rgb
|
65 |
+
|
66 |
+
# from https://github.com/mihdalal/manipgen/blob/master/manipgen/utils/obs_utils.py
|
67 |
+
class DepthWarping(torch.nn.Module):
|
68 |
+
def __init__(self, std=0.5, prob=0.8):
|
69 |
+
super().__init__()
|
70 |
+
self.std = std
|
71 |
+
self.prob = prob
|
72 |
+
|
73 |
+
def forward(self, depths, device=None):
|
74 |
+
if device is None:
|
75 |
+
device = depths.device
|
76 |
+
|
77 |
+
n, _, h, w = depths.shape
|
78 |
+
|
79 |
+
# Generate Gaussian shifts
|
80 |
+
gaussian_shifts = torch.normal(mean=0, std=self.std, size=(n, h, w, 2), device=device).float()
|
81 |
+
apply_shifts = torch.rand(n, device=device) < self.prob
|
82 |
+
gaussian_shifts[~apply_shifts] = 0.0
|
83 |
+
|
84 |
+
# Create grid for the original coordinates
|
85 |
+
xx = torch.linspace(0, w - 1, w, device=device)
|
86 |
+
yy = torch.linspace(0, h - 1, h, device=device)
|
87 |
+
xx = xx.unsqueeze(0).repeat(h, 1)
|
88 |
+
yy = yy.unsqueeze(1).repeat(1, w)
|
89 |
+
grid = torch.stack((xx, yy), 2).unsqueeze(0) # Add batch dimension
|
90 |
+
|
91 |
+
# Apply Gaussian shifts to the grid
|
92 |
+
grid = grid + gaussian_shifts
|
93 |
+
|
94 |
+
# Normalize grid values to the range [-1, 1] for grid_sample
|
95 |
+
grid[..., 0] = (grid[..., 0] / (w - 1)) * 2 - 1
|
96 |
+
grid[..., 1] = (grid[..., 1] / (h - 1)) * 2 - 1
|
97 |
+
|
98 |
+
# Perform the remapping using grid_sample
|
99 |
+
depth_interp = F.grid_sample(depths, grid, mode='bilinear', padding_mode='border', align_corners=True)
|
100 |
+
|
101 |
+
# Remove the batch and channel dimensions
|
102 |
+
depth_interp = depth_interp.squeeze(0).squeeze(0)
|
103 |
+
|
104 |
+
return depth_interp
|
105 |
+
|
106 |
+
class DepthHoles(torch.nn.Module):
|
107 |
+
def __init__(self, prob=0.5, kernel_size_lower=3, kernel_size_upper=27, sigma_lower=1.0,
|
108 |
+
sigma_upper=7.0, thresh_lower=0.6, thresh_upper=0.9):
|
109 |
+
super().__init__()
|
110 |
+
self.prob = prob
|
111 |
+
self.kernel_size_lower = kernel_size_lower
|
112 |
+
self.kernel_size_upper = kernel_size_upper
|
113 |
+
self.sigma_lower = sigma_lower
|
114 |
+
self.sigma_upper = sigma_upper
|
115 |
+
self.thresh_lower = thresh_lower
|
116 |
+
self.thresh_upper = thresh_upper
|
117 |
+
|
118 |
+
def forward(self, depths, device=None):
|
119 |
+
if device is None:
|
120 |
+
device = depths.device
|
121 |
+
|
122 |
+
n, _, h, w = depths.shape
|
123 |
+
# generate random noise
|
124 |
+
noise = torch.rand(n, 1, h, w, device=device)
|
125 |
+
|
126 |
+
# apply gaussian blur
|
127 |
+
k = random.choice(list(range(self.kernel_size_lower, self.kernel_size_upper+1, 2)))
|
128 |
+
noise = GaussianBlur(kernel_size=k, sigma=(self.sigma_lower, self.sigma_upper))(noise)
|
129 |
+
|
130 |
+
# normalize noise
|
131 |
+
noise = (noise - noise.min()) / (noise.max() - noise.min())
|
132 |
+
|
133 |
+
# apply thresholding
|
134 |
+
thresh = torch.rand(n, 1, 1, 1, device=device) * (self.thresh_upper - self.thresh_lower) + self.thresh_lower
|
135 |
+
mask = (noise > thresh)
|
136 |
+
prob = self.prob
|
137 |
+
keep_mask = torch.rand(n, device=device) < prob
|
138 |
+
mask[~keep_mask, :] = 0
|
139 |
+
|
140 |
+
return mask
|
141 |
+
|
142 |
+
class DepthNoise(torch.nn.Module):
|
143 |
+
def __init__(self, std=0.005,prob=1.0):
|
144 |
+
super().__init__()
|
145 |
+
self.std = std
|
146 |
+
self.prob = prob
|
147 |
+
|
148 |
+
def forward(self, depths, device=None):
|
149 |
+
if device is None:
|
150 |
+
device = depths.device
|
151 |
+
|
152 |
+
n, _, h, w = depths.shape
|
153 |
+
apply_noise = torch.rand(n, device=device) < self.prob
|
154 |
+
noise = torch.randn(n, 1, h, w, device=device) * self.std
|
155 |
+
noise[~apply_noise] = 0.0
|
156 |
+
return depths + noise
|
157 |
+
|
158 |
+
class Augmentor(torch.nn.Module):
|
159 |
+
def __init__(self, depth_holes=DepthHoles(), depth_warping=DepthWarping(),depth_noise=DepthNoise(),
|
160 |
+
rgb_operators=[ChangeBright(),SaltAndPepper(),ChangeContrast(),RGBGaussianNoise()]):
|
161 |
+
super().__init__()
|
162 |
+
self.depth_holes = depth_holes
|
163 |
+
self.depth_warping = depth_warping
|
164 |
+
self.depth_noise = depth_noise
|
165 |
+
self.rgb_operators = rgb_operators
|
166 |
+
|
167 |
+
def forward(self, batch):
|
168 |
+
input_depths = batch['input_cams']['depths']
|
169 |
+
if self.depth_holes.prob > 0:
|
170 |
+
masks = self.depth_holes(input_depths)
|
171 |
+
batch['input_cams']['valid_masks'][masks] = False
|
172 |
+
#if self.depth_warping.prob > 0:
|
173 |
+
#input_depths = self.depth_warping(input_depths)
|
174 |
+
if self.depth_noise.prob > 0:
|
175 |
+
input_depths = self.depth_noise(input_depths)
|
176 |
+
|
177 |
+
input_rgbs = batch['input_cams']['imgs'].squeeze(1).cpu().numpy() # this is a bit inefficient, but it's ok..
|
178 |
+
for op in self.rgb_operators:
|
179 |
+
input_rgbs = op(input_rgbs)
|
180 |
+
batch['input_cams']['imgs'] = torch.from_numpy(input_rgbs).cuda().unsqueeze(1)
|
181 |
+
|
182 |
+
batch['input_cams']['depths'] = input_depths
|
183 |
+
batch['input_cams']['pointmaps'] = compute_pointmaps(batch['input_cams']['depths'],batch['input_cams']['Ks'],batch['input_cams']['c2ws']) # now we're doing this twice, but alas
|
184 |
+
return batch
|
utils/batch_prep.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision.transforms as tvf
|
3 |
+
|
4 |
+
dino_patch_size = 14
|
5 |
+
|
6 |
+
def batch_to_device(batch,device='cuda'):
|
7 |
+
for key in batch:
|
8 |
+
if isinstance(batch[key],torch.Tensor):
|
9 |
+
batch[key] = batch[key].to(device)
|
10 |
+
elif isinstance(batch[key],dict):
|
11 |
+
batch[key] = batch_to_device(batch[key],device)
|
12 |
+
return batch
|
13 |
+
|
14 |
+
|
15 |
+
def compute_pointmap(depth: torch.Tensor, intrinsics: torch.Tensor, cam2world: torch.Tensor = None) -> torch.Tensor:
|
16 |
+
fx, fy = intrinsics[0, 0], intrinsics[1, 1]
|
17 |
+
cx, cy = intrinsics[0, 2], intrinsics[1, 2]
|
18 |
+
h, w = depth.shape
|
19 |
+
|
20 |
+
i, j = torch.meshgrid(torch.arange(w), torch.arange(h), indexing='xy')
|
21 |
+
i = i.to(depth.device)
|
22 |
+
j = j.to(depth.device)
|
23 |
+
|
24 |
+
x_cam = (i - cx) * depth / fx
|
25 |
+
y_cam = (j - cy) * depth / fy
|
26 |
+
|
27 |
+
points_cam = torch.stack([x_cam, y_cam, depth], axis=-1)
|
28 |
+
|
29 |
+
if cam2world is not None:
|
30 |
+
points_cam = torch.matmul(cam2world[:3, :3], points_cam.reshape(-1, 3).T).T + cam2world[:3, 3]
|
31 |
+
points_cam = points_cam.reshape(h, w, 3)
|
32 |
+
|
33 |
+
return points_cam
|
34 |
+
|
35 |
+
def compute_pointmaps(depths: torch.Tensor, intrinsics: torch.Tensor, cam2worlds: torch.Tensor) -> torch.Tensor:
|
36 |
+
pointmaps = []
|
37 |
+
depth_shape = depths.shape
|
38 |
+
pointmaps_shape = depths.shape + (3,)
|
39 |
+
for depth, K, c2w in zip(depths, intrinsics, cam2worlds):
|
40 |
+
n_views = depth.shape[0]
|
41 |
+
for i in range(n_views):
|
42 |
+
pointmaps.append(compute_pointmap(depth[i], K[i],c2w[i]))
|
43 |
+
return torch.stack(pointmaps).reshape(pointmaps_shape)
|
44 |
+
|
45 |
+
def depth_to_metric(depth):
|
46 |
+
# depth: shape H x W
|
47 |
+
# we want to convert the depth to a metric depth
|
48 |
+
depth_max = 10.0
|
49 |
+
depth_scaled = depth_max * (depth / 65535.0)
|
50 |
+
|
51 |
+
return depth_scaled
|
52 |
+
|
53 |
+
def make_rgb_transform() -> tvf.Compose:
|
54 |
+
return tvf.Compose([
|
55 |
+
#tvf.ToTensor(),
|
56 |
+
#lambda x: 255.0 * x[:3], # Discard alpha component and scale by 255
|
57 |
+
tvf.Normalize(
|
58 |
+
mean=(123.675, 116.28, 103.53),
|
59 |
+
std=(58.395, 57.12, 57.375),
|
60 |
+
),
|
61 |
+
])
|
62 |
+
|
63 |
+
rgb_transform = make_rgb_transform()
|
64 |
+
|
65 |
+
def compute_dino_and_store_features(dino_model : torch.nn.Module, rgb: torch.Tensor, mask: torch.Tensor,dino_layers: list[int] = None) -> torch.Tensor:
|
66 |
+
"""Computes the DINO features given an RGB image."""
|
67 |
+
rgb = rgb.squeeze(1)
|
68 |
+
mask = mask.squeeze(1)
|
69 |
+
rgb = rgb.permute(0,3,1,2)
|
70 |
+
mask = mask.unsqueeze(1).repeat(1,3,1,1)
|
71 |
+
rgb = rgb * mask
|
72 |
+
|
73 |
+
rgb = rgb.float()
|
74 |
+
H, W = rgb.shape[-2:]
|
75 |
+
goal_H, goal_W = H//dino_patch_size*dino_patch_size, W//dino_patch_size*dino_patch_size
|
76 |
+
resize_transform = tvf.CenterCrop([goal_H, goal_W])
|
77 |
+
with torch.no_grad():
|
78 |
+
rgb = resize_transform(rgb)
|
79 |
+
rgb = rgb_transform(rgb)
|
80 |
+
all_feat = dino_model.get_intermediate_layers(rgb, dino_layers)
|
81 |
+
dino_feat = torch.cat(all_feat, dim=-1)
|
82 |
+
return dino_feat
|
83 |
+
|
84 |
+
|
85 |
+
def prepare_fast_batch(batch,dino_model = None,dino_layers = None):
|
86 |
+
# depth to metric
|
87 |
+
batch['new_cams']['depths'] = depth_to_metric(batch['new_cams']['depths'])
|
88 |
+
batch['input_cams']['depths'] = depth_to_metric(batch['input_cams']['depths'])
|
89 |
+
|
90 |
+
# compute pointmaps
|
91 |
+
batch['new_cams']['pointmaps'] = compute_pointmaps(batch['new_cams']['depths'],batch['new_cams']['Ks'],batch['new_cams']['c2ws'])
|
92 |
+
batch['input_cams']['pointmaps'] = compute_pointmaps(batch['input_cams']['depths'],batch['input_cams']['Ks'],batch['input_cams']['c2ws'])
|
93 |
+
|
94 |
+
# compute dino features
|
95 |
+
if dino_model is not None and len(dino_layers) > 0:
|
96 |
+
batch['input_cams']['dino_features'] = compute_dino_and_store_features(dino_model,batch['input_cams']['imgs'],batch['input_cams']['valid_masks'],dino_layers)
|
97 |
+
|
98 |
+
return batch
|
99 |
+
|
100 |
+
|
101 |
+
def normalize_batch(batch,normalize_mode):
|
102 |
+
scale_factors = []
|
103 |
+
if normalize_mode == 'None':
|
104 |
+
pass
|
105 |
+
elif normalize_mode == 'median':
|
106 |
+
B = batch['input_cams']['valid_masks'].shape[0]
|
107 |
+
for b in range(B):
|
108 |
+
input_mask = batch['input_cams']['valid_masks'][b]
|
109 |
+
depth_median = batch['input_cams']['depths'][b][input_mask].median()
|
110 |
+
scale_factor = 1.0 / depth_median
|
111 |
+
scale_factors.append(scale_factor)
|
112 |
+
batch['input_cams']['depths'][b] = scale_factor * batch['input_cams']['depths'][b]
|
113 |
+
batch['input_cams']['pointmaps'][b] = scale_factor * batch['input_cams']['pointmaps'][b]
|
114 |
+
batch['input_cams']['c2ws'][b][0,:3,-1] = scale_factor * batch['input_cams']['c2ws'][b][0,:3,-1]
|
115 |
+
|
116 |
+
batch['new_cams']['depths'][b] = scale_factor * batch['new_cams']['depths'][b]
|
117 |
+
batch['new_cams']['pointmaps'][b] = scale_factor * batch['new_cams']['pointmaps'][b]
|
118 |
+
batch['new_cams']['c2ws'][b][:,:3,-1] = scale_factor * batch['new_cams']['c2ws'][b][:,:3,-1]
|
119 |
+
|
120 |
+
return batch, scale_factors
|
121 |
+
|
122 |
+
def denormalize_batch(batch,pred,gt,scale_factors):
|
123 |
+
B = len(scale_factors)
|
124 |
+
n_new_cams = batch['new_cams']['c2ws'].shape[1]
|
125 |
+
for b in range(B):
|
126 |
+
new_scale_factor = 1.0 / scale_factors[b]
|
127 |
+
batch['input_cams']['depths'][b] = new_scale_factor * batch['input_cams']['depths'][b]
|
128 |
+
batch['input_cams']['pointmaps'][b] = new_scale_factor * batch['input_cams']['pointmaps'][b]
|
129 |
+
batch['input_cams']['c2ws'][b][:,:3,-1] = new_scale_factor * batch['input_cams']['c2ws'][b][:,:3,-1]
|
130 |
+
batch['new_cams']['depths'][b] = new_scale_factor * batch['new_cams']['depths'][b]
|
131 |
+
batch['new_cams']['pointmaps'][b] = new_scale_factor * batch['new_cams']['pointmaps'][b]
|
132 |
+
batch['new_cams']['c2ws'][b][:,:3,-1] = new_scale_factor * batch['new_cams']['c2ws'][b][:,:3,-1]
|
133 |
+
|
134 |
+
pred['depths'][b] = new_scale_factor * pred['depths'][b]
|
135 |
+
|
136 |
+
gt['c2ws'][b][:,:3,-1] = new_scale_factor * gt['c2ws'][b][:,:3,-1]
|
137 |
+
gt['depths'][b] = new_scale_factor * gt['depths'][b]
|
138 |
+
|
139 |
+
gt['pointmaps'][b] = compute_pointmaps(gt['depths'][b].unsqueeze(1),gt['Ks'][b].unsqueeze(1),gt['c2ws'][b].unsqueeze(1)).squeeze(1)
|
140 |
+
pred['pointmaps'][b] = compute_pointmaps(pred['depths'][b].unsqueeze(1),gt['Ks'][b].unsqueeze(1),gt['c2ws'][b].unsqueeze(1)).squeeze(1)
|
141 |
+
return batch, pred, gt
|
utils/collate.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
def collate(batch):
|
4 |
+
if isinstance(batch[0],dict):
|
5 |
+
return {k: collate([d[k] for d in batch]) for k in batch[0].keys()}
|
6 |
+
else:
|
7 |
+
return torch.stack([torch.stack(t) for t in batch])
|
utils/eval.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
def eval_pred(pred_dict, gt_dict,accuracy_tresh=[0.001,0.01,0.02,0.05,0.1,0.5]):
|
4 |
+
pointmaps_pred = pred_dict['pointmaps']
|
5 |
+
pointmaps_gt = gt_dict['pointmaps']
|
6 |
+
mask = gt_dict['valid_masks'].unsqueeze(-1).repeat(1,1,1,3)
|
7 |
+
|
8 |
+
points_pred = pointmaps_pred[mask].reshape(-1,3)
|
9 |
+
points_gt = pointmaps_gt[mask].reshape(-1,3)
|
10 |
+
dists = torch.norm(points_pred - points_gt, dim=1)
|
11 |
+
results = {'dist':dists.mean().detach().item()}
|
12 |
+
if 'classifier' in pred_dict:
|
13 |
+
classifier_pred = (torch.sigmoid(pred_dict['classifier']) > 0.5).bool()
|
14 |
+
classifier_gt = gt_dict['valid_masks']
|
15 |
+
results['classifier_acc'] = (classifier_pred == classifier_gt).float().mean().detach().item()
|
16 |
+
|
17 |
+
for tresh in accuracy_tresh:
|
18 |
+
acc = (dists < tresh).float().mean()
|
19 |
+
results[f'acc_{tresh}'] = acc.detach().item()
|
20 |
+
return results
|
utils/fusion.py
ADDED
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2018 Andy Zeng
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from numba import njit, prange
|
6 |
+
from skimage import measure
|
7 |
+
|
8 |
+
try:
|
9 |
+
import pycuda.driver as cuda
|
10 |
+
import pycuda.autoinit
|
11 |
+
from pycuda.compiler import SourceModule
|
12 |
+
FUSION_GPU_MODE = 1
|
13 |
+
except Exception as err:
|
14 |
+
print('Warning: {}'.format(err))
|
15 |
+
print('Failed to import PyCUDA. Running fusion in CPU mode.')
|
16 |
+
FUSION_GPU_MODE = 0
|
17 |
+
|
18 |
+
|
19 |
+
class TSDFVolume:
|
20 |
+
"""Volumetric TSDF Fusion of RGB-D Images.
|
21 |
+
"""
|
22 |
+
def __init__(self, vol_bnds, voxel_size, use_gpu=True):
|
23 |
+
"""Constructor.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
vol_bnds (ndarray): An ndarray of shape (3, 2). Specifies the
|
27 |
+
xyz bounds (min/max) in meters.
|
28 |
+
voxel_size (float): The volume discretization in meters.
|
29 |
+
"""
|
30 |
+
vol_bnds = np.asarray(vol_bnds)
|
31 |
+
assert vol_bnds.shape == (3, 2), "[!] `vol_bnds` should be of shape (3, 2)."
|
32 |
+
|
33 |
+
# Define voxel volume parameters
|
34 |
+
self._vol_bnds = vol_bnds
|
35 |
+
self._voxel_size = float(voxel_size)
|
36 |
+
self._trunc_margin = 5 * self._voxel_size # truncation on SDF
|
37 |
+
self._color_const = 256 * 256
|
38 |
+
|
39 |
+
# Adjust volume bounds and ensure C-order contiguous
|
40 |
+
self._vol_dim = np.ceil((self._vol_bnds[:,1]-self._vol_bnds[:,0])/self._voxel_size).copy(order='C').astype(int)
|
41 |
+
self._vol_bnds[:,1] = self._vol_bnds[:,0]+self._vol_dim*self._voxel_size
|
42 |
+
self._vol_origin = self._vol_bnds[:,0].copy(order='C').astype(np.float32)
|
43 |
+
|
44 |
+
print("Voxel volume size: {} x {} x {} - # points: {:,}".format(
|
45 |
+
self._vol_dim[0], self._vol_dim[1], self._vol_dim[2],
|
46 |
+
self._vol_dim[0]*self._vol_dim[1]*self._vol_dim[2])
|
47 |
+
)
|
48 |
+
|
49 |
+
# Initialize pointers to voxel volume in CPU memory
|
50 |
+
self._tsdf_vol_cpu = np.ones(self._vol_dim).astype(np.float32)
|
51 |
+
# for computing the cumulative moving average of observations per voxel
|
52 |
+
self._weight_vol_cpu = np.zeros(self._vol_dim).astype(np.float32)
|
53 |
+
self._color_vol_cpu = np.zeros(self._vol_dim).astype(np.float32)
|
54 |
+
|
55 |
+
#self.gpu_mode = False # CPU for debugging!!
|
56 |
+
self.gpu_mode = use_gpu and FUSION_GPU_MODE
|
57 |
+
|
58 |
+
# Copy voxel volumes to GPU
|
59 |
+
if self.gpu_mode:
|
60 |
+
self._tsdf_vol_gpu = cuda.mem_alloc(self._tsdf_vol_cpu.nbytes)
|
61 |
+
cuda.memcpy_htod(self._tsdf_vol_gpu,self._tsdf_vol_cpu)
|
62 |
+
self._weight_vol_gpu = cuda.mem_alloc(self._weight_vol_cpu.nbytes)
|
63 |
+
cuda.memcpy_htod(self._weight_vol_gpu,self._weight_vol_cpu)
|
64 |
+
self._color_vol_gpu = cuda.mem_alloc(self._color_vol_cpu.nbytes)
|
65 |
+
cuda.memcpy_htod(self._color_vol_gpu,self._color_vol_cpu)
|
66 |
+
|
67 |
+
# Cuda kernel function (C++)
|
68 |
+
self._cuda_src_mod = SourceModule("""
|
69 |
+
__global__ void integrate(float * tsdf_vol,
|
70 |
+
float * weight_vol,
|
71 |
+
float * color_vol,
|
72 |
+
float * vol_dim,
|
73 |
+
float * vol_origin,
|
74 |
+
float * cam_intr,
|
75 |
+
float * cam_pose,
|
76 |
+
float * other_params,
|
77 |
+
float * color_im,
|
78 |
+
float * depth_im) {
|
79 |
+
// Get voxel index
|
80 |
+
int gpu_loop_idx = (int) other_params[0];
|
81 |
+
int max_threads_per_block = blockDim.x;
|
82 |
+
int block_idx = blockIdx.z*gridDim.y*gridDim.x+blockIdx.y*gridDim.x+blockIdx.x;
|
83 |
+
int voxel_idx = gpu_loop_idx*gridDim.x*gridDim.y*gridDim.z*max_threads_per_block+block_idx*max_threads_per_block+threadIdx.x;
|
84 |
+
int vol_dim_x = (int) vol_dim[0];
|
85 |
+
int vol_dim_y = (int) vol_dim[1];
|
86 |
+
int vol_dim_z = (int) vol_dim[2];
|
87 |
+
if (voxel_idx > vol_dim_x*vol_dim_y*vol_dim_z)
|
88 |
+
return;
|
89 |
+
// Get voxel grid coordinates (note: be careful when casting)
|
90 |
+
float voxel_x = floorf(((float)voxel_idx)/((float)(vol_dim_y*vol_dim_z)));
|
91 |
+
float voxel_y = floorf(((float)(voxel_idx-((int)voxel_x)*vol_dim_y*vol_dim_z))/((float)vol_dim_z));
|
92 |
+
float voxel_z = (float)(voxel_idx-((int)voxel_x)*vol_dim_y*vol_dim_z-((int)voxel_y)*vol_dim_z);
|
93 |
+
// Voxel grid coordinates to world coordinates
|
94 |
+
float voxel_size = other_params[1];
|
95 |
+
float pt_x = vol_origin[0]+voxel_x*voxel_size;
|
96 |
+
float pt_y = vol_origin[1]+voxel_y*voxel_size;
|
97 |
+
float pt_z = vol_origin[2]+voxel_z*voxel_size;
|
98 |
+
// World coordinates to camera coordinates
|
99 |
+
float tmp_pt_x = pt_x-cam_pose[0*4+3];
|
100 |
+
float tmp_pt_y = pt_y-cam_pose[1*4+3];
|
101 |
+
float tmp_pt_z = pt_z-cam_pose[2*4+3];
|
102 |
+
float cam_pt_x = cam_pose[0*4+0]*tmp_pt_x+cam_pose[1*4+0]*tmp_pt_y+cam_pose[2*4+0]*tmp_pt_z;
|
103 |
+
float cam_pt_y = cam_pose[0*4+1]*tmp_pt_x+cam_pose[1*4+1]*tmp_pt_y+cam_pose[2*4+1]*tmp_pt_z;
|
104 |
+
float cam_pt_z = cam_pose[0*4+2]*tmp_pt_x+cam_pose[1*4+2]*tmp_pt_y+cam_pose[2*4+2]*tmp_pt_z;
|
105 |
+
// Camera coordinates to image pixels
|
106 |
+
int pixel_x = (int) roundf(cam_intr[0*3+0]*(cam_pt_x/cam_pt_z)+cam_intr[0*3+2]);
|
107 |
+
int pixel_y = (int) roundf(cam_intr[1*3+1]*(cam_pt_y/cam_pt_z)+cam_intr[1*3+2]);
|
108 |
+
// Skip if outside view frustum
|
109 |
+
int im_h = (int) other_params[2];
|
110 |
+
int im_w = (int) other_params[3];
|
111 |
+
if (pixel_x < 0 || pixel_x >= im_w || pixel_y < 0 || pixel_y >= im_h || cam_pt_z<0)
|
112 |
+
return;
|
113 |
+
// Skip invalid depth
|
114 |
+
float depth_value = depth_im[pixel_y*im_w+pixel_x];
|
115 |
+
if (depth_value == 0)
|
116 |
+
return;
|
117 |
+
// Integrate TSDF
|
118 |
+
float trunc_margin = other_params[4];
|
119 |
+
float depth_diff = depth_value-cam_pt_z;
|
120 |
+
if (depth_diff < -trunc_margin)
|
121 |
+
return;
|
122 |
+
float dist = fmin(1.0f,depth_diff/trunc_margin);
|
123 |
+
float w_old = weight_vol[voxel_idx];
|
124 |
+
float obs_weight = other_params[5];
|
125 |
+
float w_new = w_old + obs_weight;
|
126 |
+
weight_vol[voxel_idx] = w_new;
|
127 |
+
tsdf_vol[voxel_idx] = (tsdf_vol[voxel_idx]*w_old+obs_weight*dist)/w_new;
|
128 |
+
// Integrate color
|
129 |
+
float old_color = color_vol[voxel_idx];
|
130 |
+
float old_b = floorf(old_color/(256*256));
|
131 |
+
float old_g = floorf((old_color-old_b*256*256)/256);
|
132 |
+
float old_r = old_color-old_b*256*256-old_g*256;
|
133 |
+
float new_color = color_im[pixel_y*im_w+pixel_x];
|
134 |
+
float new_b = floorf(new_color/(256*256));
|
135 |
+
float new_g = floorf((new_color-new_b*256*256)/256);
|
136 |
+
float new_r = new_color-new_b*256*256-new_g*256;
|
137 |
+
new_b = fmin(roundf((old_b*w_old+obs_weight*new_b)/w_new),255.0f);
|
138 |
+
new_g = fmin(roundf((old_g*w_old+obs_weight*new_g)/w_new),255.0f);
|
139 |
+
new_r = fmin(roundf((old_r*w_old+obs_weight*new_r)/w_new),255.0f);
|
140 |
+
color_vol[voxel_idx] = new_b*256*256+new_g*256+new_r;
|
141 |
+
}""")
|
142 |
+
|
143 |
+
self._cuda_integrate = self._cuda_src_mod.get_function("integrate")
|
144 |
+
|
145 |
+
# Determine block/grid size on GPU
|
146 |
+
gpu_dev = cuda.Device(0)
|
147 |
+
self._max_gpu_threads_per_block = gpu_dev.MAX_THREADS_PER_BLOCK
|
148 |
+
n_blocks = int(np.ceil(float(np.prod(self._vol_dim))/float(self._max_gpu_threads_per_block)))
|
149 |
+
grid_dim_x = min(gpu_dev.MAX_GRID_DIM_X,int(np.floor(np.cbrt(n_blocks))))
|
150 |
+
grid_dim_y = min(gpu_dev.MAX_GRID_DIM_Y,int(np.floor(np.sqrt(n_blocks/grid_dim_x))))
|
151 |
+
grid_dim_z = min(gpu_dev.MAX_GRID_DIM_Z,int(np.ceil(float(n_blocks)/float(grid_dim_x*grid_dim_y))))
|
152 |
+
self._max_gpu_grid_dim = np.array([grid_dim_x,grid_dim_y,grid_dim_z]).astype(int)
|
153 |
+
self._n_gpu_loops = int(np.ceil(float(np.prod(self._vol_dim))/float(np.prod(self._max_gpu_grid_dim)*self._max_gpu_threads_per_block)))
|
154 |
+
|
155 |
+
else:
|
156 |
+
# Get voxel grid coordinates
|
157 |
+
xv, yv, zv = np.meshgrid(
|
158 |
+
range(self._vol_dim[0]),
|
159 |
+
range(self._vol_dim[1]),
|
160 |
+
range(self._vol_dim[2]),
|
161 |
+
indexing='ij'
|
162 |
+
)
|
163 |
+
self.vox_coords = np.concatenate([
|
164 |
+
xv.reshape(1,-1),
|
165 |
+
yv.reshape(1,-1),
|
166 |
+
zv.reshape(1,-1)
|
167 |
+
], axis=0).astype(int).T
|
168 |
+
|
169 |
+
@staticmethod
|
170 |
+
@njit(parallel=True)
|
171 |
+
def vox2world(vol_origin, vox_coords, vox_size):
|
172 |
+
"""Convert voxel grid coordinates to world coordinates.
|
173 |
+
"""
|
174 |
+
vol_origin = vol_origin.astype(np.float32)
|
175 |
+
vox_coords = vox_coords.astype(np.float32)
|
176 |
+
cam_pts = np.empty_like(vox_coords, dtype=np.float32)
|
177 |
+
for i in prange(vox_coords.shape[0]):
|
178 |
+
for j in range(3):
|
179 |
+
cam_pts[i, j] = vol_origin[j] + (vox_size * vox_coords[i, j])
|
180 |
+
return cam_pts
|
181 |
+
|
182 |
+
@staticmethod
|
183 |
+
@njit(parallel=True)
|
184 |
+
def cam2pix(cam_pts, intr):
|
185 |
+
"""Convert camera coordinates to pixel coordinates.
|
186 |
+
"""
|
187 |
+
intr = intr.astype(np.float32)
|
188 |
+
fx, fy = intr[0, 0], intr[1, 1]
|
189 |
+
cx, cy = intr[0, 2], intr[1, 2]
|
190 |
+
pix = np.empty((cam_pts.shape[0], 2), dtype=np.int64)
|
191 |
+
for i in prange(cam_pts.shape[0]):
|
192 |
+
pix[i, 0] = int(np.round((cam_pts[i, 0] * fx / cam_pts[i, 2]) + cx))
|
193 |
+
pix[i, 1] = int(np.round((cam_pts[i, 1] * fy / cam_pts[i, 2]) + cy))
|
194 |
+
return pix
|
195 |
+
|
196 |
+
@staticmethod
|
197 |
+
@njit(parallel=True)
|
198 |
+
def integrate_tsdf(tsdf_vol, dist, w_old, obs_weight):
|
199 |
+
"""Integrate the TSDF volume.
|
200 |
+
"""
|
201 |
+
tsdf_vol_int = np.empty_like(tsdf_vol, dtype=np.float32)
|
202 |
+
w_new = np.empty_like(w_old, dtype=np.float32)
|
203 |
+
for i in prange(len(tsdf_vol)):
|
204 |
+
w_new[i] = w_old[i] + obs_weight
|
205 |
+
tsdf_vol_int[i] = (w_old[i] * tsdf_vol[i] + obs_weight * dist[i]) / w_new[i]
|
206 |
+
return tsdf_vol_int, w_new
|
207 |
+
|
208 |
+
def integrate(self, color_im, depth_im, cam_intr, cam_pose, obs_weight=1.,mask=None):
|
209 |
+
"""Integrate an RGB-D frame into the TSDF volume.
|
210 |
+
|
211 |
+
Args:
|
212 |
+
color_im (ndarray): An RGB image of shape (H, W, 3).
|
213 |
+
depth_im (ndarray): A depth image of shape (H, W).
|
214 |
+
cam_intr (ndarray): The camera intrinsics matrix of shape (3, 3).
|
215 |
+
cam_pose (ndarray): The camera pose (i.e. extrinsics) of shape (4, 4).
|
216 |
+
obs_weight (float): The weight to assign for the current observation. A higher
|
217 |
+
value
|
218 |
+
"""
|
219 |
+
im_h, im_w = depth_im.shape
|
220 |
+
|
221 |
+
# Fold RGB color image into a single channel image
|
222 |
+
color_im = color_im.astype(np.float32)
|
223 |
+
color_im = np.floor(color_im[...,2]*self._color_const + color_im[...,1]*256 + color_im[...,0])
|
224 |
+
|
225 |
+
if self.gpu_mode: # GPU mode: integrate voxel volume (calls CUDA kernel)
|
226 |
+
# no mask implemented yet
|
227 |
+
for gpu_loop_idx in range(self._n_gpu_loops):
|
228 |
+
self._cuda_integrate(self._tsdf_vol_gpu,
|
229 |
+
self._weight_vol_gpu,
|
230 |
+
self._color_vol_gpu,
|
231 |
+
cuda.InOut(self._vol_dim.astype(np.float32)),
|
232 |
+
cuda.InOut(self._vol_origin.astype(np.float32)),
|
233 |
+
cuda.InOut(cam_intr.reshape(-1).astype(np.float32)),
|
234 |
+
cuda.InOut(cam_pose.reshape(-1).astype(np.float32)),
|
235 |
+
cuda.InOut(np.asarray([
|
236 |
+
gpu_loop_idx,
|
237 |
+
self._voxel_size,
|
238 |
+
im_h,
|
239 |
+
im_w,
|
240 |
+
self._trunc_margin,
|
241 |
+
obs_weight
|
242 |
+
], np.float32)),
|
243 |
+
cuda.InOut(color_im.reshape(-1).astype(np.float32)),
|
244 |
+
cuda.InOut(depth_im.reshape(-1).astype(np.float32)),
|
245 |
+
block=(self._max_gpu_threads_per_block,1,1),
|
246 |
+
grid=(
|
247 |
+
int(self._max_gpu_grid_dim[0]),
|
248 |
+
int(self._max_gpu_grid_dim[1]),
|
249 |
+
int(self._max_gpu_grid_dim[2]),
|
250 |
+
)
|
251 |
+
)
|
252 |
+
else: # CPU mode: integrate voxel volume (vectorized implementation)
|
253 |
+
# Convert voxel grid coordinates to pixel coordinates
|
254 |
+
cam_pts = self.vox2world(self._vol_origin, self.vox_coords, self._voxel_size)
|
255 |
+
cam_pts = rigid_transform(cam_pts, np.linalg.inv(cam_pose))
|
256 |
+
pix_z = cam_pts[:, 2]
|
257 |
+
pix = self.cam2pix(cam_pts, cam_intr)
|
258 |
+
pix_x, pix_y = pix[:, 0], pix[:, 1]
|
259 |
+
|
260 |
+
# Eliminate pixels outside view frustum
|
261 |
+
valid_pix = np.logical_and(pix_x >= 0,
|
262 |
+
np.logical_and(pix_x < im_w,
|
263 |
+
np.logical_and(pix_y >= 0,
|
264 |
+
np.logical_and(pix_y < im_h,
|
265 |
+
pix_z > 0))))
|
266 |
+
if mask is not None:
|
267 |
+
mask_queries = mask[pix_y[valid_pix],pix_x[valid_pix]]
|
268 |
+
valid_pix[valid_pix] = np.logical_and(valid_pix[valid_pix],mask_queries)
|
269 |
+
|
270 |
+
depth_val = np.zeros(pix_x.shape)
|
271 |
+
depth_val[valid_pix] = depth_im[pix_y[valid_pix], pix_x[valid_pix]]
|
272 |
+
|
273 |
+
# Integrate TSDF
|
274 |
+
depth_diff = depth_val - pix_z
|
275 |
+
valid_pts = np.logical_and(depth_val > 0, depth_diff >= -self._trunc_margin)
|
276 |
+
dist = np.minimum(1, depth_diff / self._trunc_margin)
|
277 |
+
valid_vox_x = self.vox_coords[valid_pts, 0]
|
278 |
+
valid_vox_y = self.vox_coords[valid_pts, 1]
|
279 |
+
valid_vox_z = self.vox_coords[valid_pts, 2]
|
280 |
+
w_old = self._weight_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z]
|
281 |
+
tsdf_vals = self._tsdf_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z]
|
282 |
+
valid_dist = dist[valid_pts]
|
283 |
+
tsdf_vol_new, w_new = self.integrate_tsdf(tsdf_vals, valid_dist, w_old, obs_weight)
|
284 |
+
self._weight_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = w_new
|
285 |
+
self._tsdf_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = tsdf_vol_new
|
286 |
+
|
287 |
+
# Integrate color
|
288 |
+
old_color = self._color_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z]
|
289 |
+
old_b = np.floor(old_color / self._color_const)
|
290 |
+
old_g = np.floor((old_color-old_b*self._color_const)/256)
|
291 |
+
old_r = old_color - old_b*self._color_const - old_g*256
|
292 |
+
new_color = color_im[pix_y[valid_pts],pix_x[valid_pts]]
|
293 |
+
new_b = np.floor(new_color / self._color_const)
|
294 |
+
new_g = np.floor((new_color - new_b*self._color_const) /256)
|
295 |
+
new_r = new_color - new_b*self._color_const - new_g*256
|
296 |
+
new_b = np.minimum(255., np.round((w_old*old_b + obs_weight*new_b) / w_new))
|
297 |
+
new_g = np.minimum(255., np.round((w_old*old_g + obs_weight*new_g) / w_new))
|
298 |
+
new_r = np.minimum(255., np.round((w_old*old_r + obs_weight*new_r) / w_new))
|
299 |
+
self._color_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = new_b*self._color_const + new_g*256 + new_r
|
300 |
+
|
301 |
+
def get_volume(self):
|
302 |
+
if self.gpu_mode:
|
303 |
+
cuda.memcpy_dtoh(self._tsdf_vol_cpu, self._tsdf_vol_gpu)
|
304 |
+
cuda.memcpy_dtoh(self._color_vol_cpu, self._color_vol_gpu)
|
305 |
+
return self._tsdf_vol_cpu, self._color_vol_cpu
|
306 |
+
|
307 |
+
def get_point_cloud(self):
|
308 |
+
"""Extract a point cloud from the voxel volume.
|
309 |
+
"""
|
310 |
+
tsdf_vol, color_vol = self.get_volume()
|
311 |
+
|
312 |
+
# Marching cubes
|
313 |
+
verts = measure.marching_cubes(tsdf_vol, level=0, method='lewiner')[0]
|
314 |
+
verts_ind = np.round(verts).astype(int)
|
315 |
+
verts = verts*self._voxel_size + self._vol_origin
|
316 |
+
|
317 |
+
# Get vertex colors
|
318 |
+
rgb_vals = color_vol[verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]]
|
319 |
+
colors_b = np.floor(rgb_vals / self._color_const)
|
320 |
+
colors_g = np.floor((rgb_vals - colors_b*self._color_const) / 256)
|
321 |
+
colors_r = rgb_vals - colors_b*self._color_const - colors_g*256
|
322 |
+
colors = np.floor(np.asarray([colors_r, colors_g, colors_b])).T
|
323 |
+
colors = colors.astype(np.uint8)
|
324 |
+
|
325 |
+
pc = np.hstack([verts, colors])
|
326 |
+
return pc
|
327 |
+
|
328 |
+
def get_mesh(self):
|
329 |
+
"""Compute a mesh from the voxel volume using marching cubes.
|
330 |
+
"""
|
331 |
+
tsdf_vol, color_vol = self.get_volume()
|
332 |
+
|
333 |
+
# Marching cubes
|
334 |
+
verts, faces, norms, vals = measure.marching_cubes(tsdf_vol, level=0, method='lewiner')
|
335 |
+
verts_ind = np.round(verts).astype(int)
|
336 |
+
verts = verts*self._voxel_size+self._vol_origin # voxel grid coordinates to world coordinates
|
337 |
+
|
338 |
+
# Get vertex colors
|
339 |
+
rgb_vals = color_vol[verts_ind[:,0], verts_ind[:,1], verts_ind[:,2]]
|
340 |
+
colors_b = np.floor(rgb_vals/self._color_const)
|
341 |
+
colors_g = np.floor((rgb_vals-colors_b*self._color_const)/256)
|
342 |
+
colors_r = rgb_vals-colors_b*self._color_const-colors_g*256
|
343 |
+
colors = np.floor(np.asarray([colors_r,colors_g,colors_b])).T
|
344 |
+
colors = colors.astype(np.uint8)
|
345 |
+
return verts, faces, norms, colors
|
346 |
+
|
347 |
+
|
348 |
+
def rigid_transform(xyz, transform):
|
349 |
+
"""Applies a rigid transform to an (N, 3) pointcloud.
|
350 |
+
"""
|
351 |
+
xyz_h = np.hstack([xyz, np.ones((len(xyz), 1), dtype=np.float32)])
|
352 |
+
xyz_t_h = np.dot(transform, xyz_h.T).T
|
353 |
+
return xyz_t_h[:, :3]
|
354 |
+
|
355 |
+
|
356 |
+
def get_view_frustum(depth_im, cam_intr, cam_pose):
|
357 |
+
"""Get corners of 3D camera view frustum of depth image
|
358 |
+
"""
|
359 |
+
im_h = depth_im.shape[0]
|
360 |
+
im_w = depth_im.shape[1]
|
361 |
+
max_depth = np.max(depth_im)
|
362 |
+
view_frust_pts = np.array([
|
363 |
+
(np.array([0,0,0,im_w,im_w])-cam_intr[0,2])*np.array([0,max_depth,max_depth,max_depth,max_depth])/cam_intr[0,0],
|
364 |
+
(np.array([0,0,im_h,0,im_h])-cam_intr[1,2])*np.array([0,max_depth,max_depth,max_depth,max_depth])/cam_intr[1,1],
|
365 |
+
np.array([0,max_depth,max_depth,max_depth,max_depth])
|
366 |
+
])
|
367 |
+
view_frust_pts = rigid_transform(view_frust_pts.T, cam_pose).T
|
368 |
+
return view_frust_pts
|
369 |
+
|
370 |
+
|
371 |
+
def meshwrite(filename, verts, faces, norms, colors):
|
372 |
+
"""Save a 3D mesh to a polygon .ply file.
|
373 |
+
"""
|
374 |
+
# Write header
|
375 |
+
ply_file = open(filename,'w')
|
376 |
+
ply_file.write("ply\n")
|
377 |
+
ply_file.write("format ascii 1.0\n")
|
378 |
+
ply_file.write("element vertex %d\n"%(verts.shape[0]))
|
379 |
+
ply_file.write("property float x\n")
|
380 |
+
ply_file.write("property float y\n")
|
381 |
+
ply_file.write("property float z\n")
|
382 |
+
ply_file.write("property float nx\n")
|
383 |
+
ply_file.write("property float ny\n")
|
384 |
+
ply_file.write("property float nz\n")
|
385 |
+
ply_file.write("property uchar red\n")
|
386 |
+
ply_file.write("property uchar green\n")
|
387 |
+
ply_file.write("property uchar blue\n")
|
388 |
+
ply_file.write("element face %d\n"%(faces.shape[0]))
|
389 |
+
ply_file.write("property list uchar int vertex_index\n")
|
390 |
+
ply_file.write("end_header\n")
|
391 |
+
|
392 |
+
# Write vertex list
|
393 |
+
for i in range(verts.shape[0]):
|
394 |
+
ply_file.write("%f %f %f %f %f %f %d %d %d\n"%(
|
395 |
+
verts[i,0], verts[i,1], verts[i,2],
|
396 |
+
norms[i,0], norms[i,1], norms[i,2],
|
397 |
+
colors[i,0], colors[i,1], colors[i,2],
|
398 |
+
))
|
399 |
+
|
400 |
+
# Write face list
|
401 |
+
for i in range(faces.shape[0]):
|
402 |
+
ply_file.write("3 %d %d %d\n"%(faces[i,0], faces[i,1], faces[i,2]))
|
403 |
+
|
404 |
+
ply_file.close()
|
405 |
+
|
406 |
+
|
407 |
+
def pcwrite(filename, xyzrgb):
|
408 |
+
"""Save a point cloud to a polygon .ply file.
|
409 |
+
"""
|
410 |
+
xyz = xyzrgb[:, :3]
|
411 |
+
rgb = xyzrgb[:, 3:].astype(np.uint8)
|
412 |
+
|
413 |
+
# Write header
|
414 |
+
ply_file = open(filename,'w')
|
415 |
+
ply_file.write("ply\n")
|
416 |
+
ply_file.write("format ascii 1.0\n")
|
417 |
+
ply_file.write("element vertex %d\n"%(xyz.shape[0]))
|
418 |
+
ply_file.write("property float x\n")
|
419 |
+
ply_file.write("property float y\n")
|
420 |
+
ply_file.write("property float z\n")
|
421 |
+
ply_file.write("property uchar red\n")
|
422 |
+
ply_file.write("property uchar green\n")
|
423 |
+
ply_file.write("property uchar blue\n")
|
424 |
+
ply_file.write("end_header\n")
|
425 |
+
|
426 |
+
# Write vertex list
|
427 |
+
for i in range(xyz.shape[0]):
|
428 |
+
ply_file.write("%f %f %f %d %d %d\n"%(
|
429 |
+
xyz[i, 0], xyz[i, 1], xyz[i, 2],
|
430 |
+
rgb[i, 0], rgb[i, 1], rgb[i, 2],
|
431 |
+
))
|
432 |
+
|
433 |
+
def get_vol_bds(pred_depths : torch.Tensor, pred_c2ws : torch.Tensor, pred_intr : torch.Tensor):
|
434 |
+
n_views = pred_depths.shape[0]
|
435 |
+
vol_bnds = np.zeros((3,2))
|
436 |
+
|
437 |
+
for i in range(n_views):
|
438 |
+
intr = pred_intr[i].cpu().numpy()
|
439 |
+
c2w = pred_c2ws[i].cpu().numpy()
|
440 |
+
depth = pred_depths[i].cpu().numpy()
|
441 |
+
view_frust_pts = get_view_frustum(depth, intr, c2w)
|
442 |
+
vol_bnds[:,0] = np.minimum(vol_bnds[:,0], np.amin(view_frust_pts, axis=1))
|
443 |
+
vol_bnds[:,1] = np.maximum(vol_bnds[:,1], np.amax(view_frust_pts, axis=1))
|
444 |
+
|
445 |
+
return vol_bnds
|
446 |
+
|
447 |
+
def fuse_batch(pred_dict: dict, gt_dict: dict, batch:dict,voxel_size: float = 0.02):
|
448 |
+
pred_depths = pred_dict['pointmaps'][...,-1] # depth here is just z, assuming the predicted point map is in camera frame
|
449 |
+
pred_c2ws = batch['new_cams']['c2ws']
|
450 |
+
pred_intr = batch['new_cams']['Ks']
|
451 |
+
pred_masks = batch['new_cams']['valid_masks']
|
452 |
+
B = pred_depths.shape[0]
|
453 |
+
n_views = pred_depths.shape[1]
|
454 |
+
|
455 |
+
meshes = []
|
456 |
+
for i in range(B):
|
457 |
+
intrs = pred_intr[i]
|
458 |
+
c2ws = pred_c2ws[i]
|
459 |
+
depths = pred_depths[i]
|
460 |
+
vol_bnds = get_vol_bds(depths, c2ws, intrs)
|
461 |
+
tsdf_vol = TSDFVolume(vol_bnds, voxel_size=voxel_size)
|
462 |
+
masks = pred_masks[i]
|
463 |
+
|
464 |
+
for j in range(n_views):
|
465 |
+
intr = intrs[j]
|
466 |
+
c2w = c2ws[j]
|
467 |
+
depth = depths[j]
|
468 |
+
mask = masks[j]
|
469 |
+
depth[~mask] = 0
|
470 |
+
img = torch.zeros_like(depth,dtype=torch.uint8).unsqueeze(-1).repeat(1,1,3)
|
471 |
+
img[:,:,-1] = 255
|
472 |
+
tsdf_vol.integrate(img.cpu().numpy(), depth.cpu().numpy(), intr.cpu().numpy(), c2w.cpu().numpy(), obs_weight=1.)
|
473 |
+
|
474 |
+
verts, faces, norms, colors = tsdf_vol.get_mesh()
|
475 |
+
meshes.append(dict(verts=verts, faces=faces, norms=norms, colors=colors))
|
476 |
+
return meshes
|
utils/geometry.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import copy
|
4 |
+
from utils.utils import invalid_to_nans, invalid_to_zeros
|
5 |
+
|
6 |
+
def compute_pointmap(depth, cam2w, intrinsics):
|
7 |
+
fx, fy = intrinsics[0, 0], intrinsics[1, 1]
|
8 |
+
cx, cy = intrinsics[0, 2], intrinsics[1, 2]
|
9 |
+
h, w = depth.shape
|
10 |
+
|
11 |
+
i, j = np.meshgrid(np.arange(w), np.arange(h), indexing='xy')
|
12 |
+
|
13 |
+
x_cam = (i - cx) * depth / fx
|
14 |
+
y_cam = (j - cy) * depth / fy
|
15 |
+
|
16 |
+
points_cam = np.stack([x_cam, y_cam, depth], axis=-1)
|
17 |
+
points_world = np.dot(cam2w[:3, :3], points_cam.reshape(-1, 3).T).T + cam2w[:3, 3]
|
18 |
+
points_world = points_world.reshape(h, w, 3)
|
19 |
+
|
20 |
+
return points_world
|
21 |
+
|
22 |
+
def invert_poses(raw_poses):
|
23 |
+
poses = copy.deepcopy(raw_poses)
|
24 |
+
original_shape = poses.shape
|
25 |
+
poses = poses.reshape(-1, 4, 4)
|
26 |
+
R = copy.deepcopy(poses[:, :3, :3])
|
27 |
+
t = copy.deepcopy(poses[:, :3, 3])
|
28 |
+
poses[:, :3, :3] = R.transpose(1, 2)
|
29 |
+
poses[:, :3, 3] = torch.bmm(-R.transpose(1, 2), t.unsqueeze(-1)).squeeze(-1)
|
30 |
+
poses = poses.reshape(*original_shape)
|
31 |
+
return poses
|
32 |
+
|
33 |
+
def center_pointmaps_set(dict,w2cs):
|
34 |
+
swap_dim = False
|
35 |
+
if dict["pointmaps"].shape[1] == 3:
|
36 |
+
swap_dim = True
|
37 |
+
dict["pointmaps"] = dict["pointmaps"].transpose(1,-1)
|
38 |
+
|
39 |
+
original_shape = dict["pointmaps"].shape
|
40 |
+
device = dict["pointmaps"].device
|
41 |
+
B = original_shape[0]
|
42 |
+
|
43 |
+
# recompute pointmaps in camera frame
|
44 |
+
pointmaps = dict["pointmaps"]
|
45 |
+
pointmaps_h = torch.cat([pointmaps,torch.ones(pointmaps.shape[:-1]+(1,)).to(device)],dim=-1)
|
46 |
+
pointmaps_h = pointmaps_h.reshape(B,-1,4)
|
47 |
+
pointmaps_recentered_h = torch.bmm(w2cs,pointmaps_h.transpose(1,2)).transpose(1,2)
|
48 |
+
pointmaps_recentered = pointmaps_recentered_h[...,:3]/pointmaps_recentered_h[...,3:4]
|
49 |
+
pointmaps_recentered = pointmaps_recentered.reshape(*original_shape)
|
50 |
+
|
51 |
+
# recompute c2ws
|
52 |
+
if "c2ws" in dict:
|
53 |
+
c2ws_recentered = torch.bmm(w2cs,dict["c2ws"].reshape(-1,4,4))
|
54 |
+
c2ws_recentered = c2ws_recentered.reshape(dict["c2ws"].shape)
|
55 |
+
dict["c2ws"] = c2ws_recentered
|
56 |
+
|
57 |
+
# assign to dict
|
58 |
+
dict["pointmaps"] = pointmaps_recentered
|
59 |
+
if swap_dim:
|
60 |
+
dict["pointmaps"] = dict["pointmaps"].transpose(1,-1)
|
61 |
+
return dict
|
62 |
+
|
63 |
+
def center_pointmaps(batch):
|
64 |
+
original_poses = batch["new_cams"]["c2ws"] # assuming first camera is the one we want to predict
|
65 |
+
w2cs = invert_poses(batch["new_cams"]["c2ws"])
|
66 |
+
|
67 |
+
batch["new_cams"] = center_pointmaps_set(batch["new_cams"],w2cs)
|
68 |
+
batch["input_cams"] = center_pointmaps_set(batch["input_cams"],w2cs)
|
69 |
+
batch["original_poses"] = original_poses
|
70 |
+
return batch
|
71 |
+
|
72 |
+
|
73 |
+
def uncenter_pointmaps(pred,gt,batch):
|
74 |
+
original_poses = batch["original_poses"]
|
75 |
+
|
76 |
+
batch["new_cams"] = center_pointmaps_set(batch["new_cams"],original_poses)
|
77 |
+
batch["input_cams"] = center_pointmaps_set(batch["input_cams"],original_poses)
|
78 |
+
|
79 |
+
#gt = center_pointmaps_set(gt,original_poses)
|
80 |
+
#pred = center_pointmaps_set(pred,original_poses)
|
81 |
+
return pred, gt, batch
|
82 |
+
|
83 |
+
def compute_rays(batch):
|
84 |
+
h, w = batch["new_cams"]["pointmaps"].shape[-3:-1]
|
85 |
+
B = batch["new_cams"]["pointmaps"].shape[0]
|
86 |
+
device = batch["new_cams"]["pointmaps"].device
|
87 |
+
Ks = batch["new_cams"]["Ks"]
|
88 |
+
i_s, j_s = np.meshgrid(np.arange(w), np.arange(h), indexing='xy')
|
89 |
+
i_s, j_s = torch.tensor(i_s).repeat(B,1,1).to(device), torch.tensor(j_s).repeat(B,1,1).to(device)
|
90 |
+
|
91 |
+
f_x = Ks[:,0,0].reshape(-1,1,1)
|
92 |
+
f_y = Ks[:,1,1].reshape(-1,1,1)
|
93 |
+
c_x = Ks[:,0,2].reshape(-1,1,1)
|
94 |
+
c_y = Ks[:,1,2].reshape(-1,1,1)
|
95 |
+
|
96 |
+
# compute rays with z=1
|
97 |
+
x_cam = (i_s - c_x) / f_x
|
98 |
+
y_cam = (j_s - c_y) / f_y
|
99 |
+
rays = torch.cat([x_cam.unsqueeze(-1),y_cam.unsqueeze(-1)],dim=-1)
|
100 |
+
return rays
|
101 |
+
|
102 |
+
def normalize_pointcloud(pts1, pts2=None, norm_mode='avg_dis', valid1=None, valid2=None, valid3=None, ret_factor=False,pts3=None):
|
103 |
+
assert pts1.ndim >= 3 and pts1.shape[-1] == 3
|
104 |
+
assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3)
|
105 |
+
norm_mode, dis_mode = norm_mode.split('_')
|
106 |
+
|
107 |
+
if norm_mode == 'avg':
|
108 |
+
# gather all points together (joint normalization)
|
109 |
+
nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3)
|
110 |
+
nan_pts2, nnz2 = invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0)
|
111 |
+
all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
|
112 |
+
if pts3 is not None:
|
113 |
+
nan_pts3, nnz3 = invalid_to_zeros(pts3, valid3, ndim=3)
|
114 |
+
all_pts = torch.cat((all_pts, nan_pts3), dim=1)
|
115 |
+
nnz1 += nnz3
|
116 |
+
# compute distance to origin
|
117 |
+
all_dis = all_pts.norm(dim=-1)
|
118 |
+
if dis_mode == 'dis':
|
119 |
+
pass # do nothing
|
120 |
+
elif dis_mode == 'log1p':
|
121 |
+
all_dis = torch.log1p(all_dis)
|
122 |
+
elif dis_mode == 'warp-log1p':
|
123 |
+
# actually warp input points before normalizing them
|
124 |
+
log_dis = torch.log1p(all_dis)
|
125 |
+
warp_factor = log_dis / all_dis.clip(min=1e-8)
|
126 |
+
H1, W1 = pts1.shape[1:-1]
|
127 |
+
pts1 = pts1 * warp_factor[:,:W1*H1].view(-1,H1,W1,1)
|
128 |
+
if pts2 is not None:
|
129 |
+
H2, W2 = pts2.shape[1:-1]
|
130 |
+
pts2 = pts2 * warp_factor[:,W1*H1:].view(-1,H2,W2,1)
|
131 |
+
all_dis = log_dis # this is their true distance afterwards
|
132 |
+
else:
|
133 |
+
raise ValueError(f'bad {dis_mode=}')
|
134 |
+
|
135 |
+
norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8)
|
136 |
+
else:
|
137 |
+
# gather all points together (joint normalization)
|
138 |
+
nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3)
|
139 |
+
nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None
|
140 |
+
all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
|
141 |
+
|
142 |
+
# compute distance to origin
|
143 |
+
all_dis = all_pts.norm(dim=-1)
|
144 |
+
|
145 |
+
if norm_mode == 'avg':
|
146 |
+
norm_factor = all_dis.nanmean(dim=1)
|
147 |
+
elif norm_mode == 'median':
|
148 |
+
norm_factor = all_dis.nanmedian(dim=1).values.detach()
|
149 |
+
elif norm_mode == 'sqrt':
|
150 |
+
norm_factor = all_dis.sqrt().nanmean(dim=1)**2
|
151 |
+
else:
|
152 |
+
raise ValueError(f'bad {norm_mode=}')
|
153 |
+
|
154 |
+
norm_factor = norm_factor.clip(min=1e-8)
|
155 |
+
while norm_factor.ndim < pts1.ndim:
|
156 |
+
norm_factor.unsqueeze_(-1)
|
157 |
+
|
158 |
+
res = (pts1 / norm_factor,)
|
159 |
+
if pts2 is not None:
|
160 |
+
res = res + (pts2 / norm_factor,)
|
161 |
+
if pts3 is not None:
|
162 |
+
res = res + (pts3 / norm_factor,)
|
163 |
+
if ret_factor:
|
164 |
+
res = res + (norm_factor,)
|
165 |
+
return res
|
166 |
+
|
167 |
+
def compute_pointmap_torch(depth, cam2w, intrinsics,device='cuda'):
|
168 |
+
fx, fy = intrinsics[0, 0], intrinsics[1, 1]
|
169 |
+
cx, cy = intrinsics[0, 2], intrinsics[1, 2]
|
170 |
+
h, w = depth.shape
|
171 |
+
|
172 |
+
#i, j = np.meshgrid(np.arange(w), np.arange(h), indexing='xy')
|
173 |
+
i, j = torch.meshgrid(torch.arange(w).to(device), torch.arange(h).to(device), indexing='xy')
|
174 |
+
x_cam = (i - cx) * depth / fx
|
175 |
+
y_cam = (j - cy) * depth / fy
|
176 |
+
|
177 |
+
points_cam = torch.stack([x_cam, y_cam, depth], dim=-1)
|
178 |
+
points_world = (cam2w[:3, :3] @ points_cam.reshape(-1, 3).T).T + cam2w[:3, 3]
|
179 |
+
points_world = points_world.reshape(h, w, 3)
|
180 |
+
|
181 |
+
return points_world
|
182 |
+
|
183 |
+
def depth2pts(depths, Ks):
|
184 |
+
"""
|
185 |
+
Convert depth map to 3D points
|
186 |
+
"""
|
187 |
+
device = depths.device
|
188 |
+
B = depths.shape[0]
|
189 |
+
pts = []
|
190 |
+
for b in range(B):
|
191 |
+
depth_b = depths[b]
|
192 |
+
K = Ks[b]
|
193 |
+
pts.append(compute_pointmap_torch(depth_b,torch.eye(4).to(device), K,device))
|
194 |
+
pts = torch.stack(pts, dim=0)
|
195 |
+
return pts
|
utils/misc.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import copy
|
3 |
+
from pathlib import Path
|
4 |
+
import torch
|
5 |
+
import torch.distributed as dist
|
6 |
+
import numpy as np
|
7 |
+
import math
|
8 |
+
import socket
|
9 |
+
# source: https://github.com/LTH14/mar/blob/main/util/misc.py
|
10 |
+
|
11 |
+
def prep_torch():
|
12 |
+
cpu_cores = get_cpu_cores()
|
13 |
+
torch.set_num_threads(cpu_cores) # intra-op threads (e.g., matrix ops)
|
14 |
+
torch.set_num_interop_threads(cpu_cores) # inter-op parallelism
|
15 |
+
|
16 |
+
os.environ["OMP_NUM_THREADS"] = str(cpu_cores)
|
17 |
+
os.environ["MKL_NUM_THREADS"] = str(cpu_cores)
|
18 |
+
os.environ["OPENBLAS_NUM_THREADS"] = str(cpu_cores)
|
19 |
+
|
20 |
+
def get_cpu_cores():
|
21 |
+
hostname = socket.gethostname()
|
22 |
+
if "bridges2" in hostname:
|
23 |
+
return int(os.environ["SLURM_JOB_CPUS_PER_NODE"])
|
24 |
+
else:
|
25 |
+
try:
|
26 |
+
with open("/sys/fs/cgroup/cpu/cpu.cfs_quota_us", "r") as f:
|
27 |
+
quota = int(f.read().strip())
|
28 |
+
with open("/sys/fs/cgroup/cpu/cpu.cfs_period_us", "r") as f:
|
29 |
+
period = int(f.read().strip())
|
30 |
+
if quota > 0:
|
31 |
+
return max(1, quota // period)
|
32 |
+
except Exception as e:
|
33 |
+
return os.cpu_count()
|
34 |
+
|
35 |
+
def setup_distributed():
|
36 |
+
dist.init_process_group(backend='nccl')
|
37 |
+
# Get the rank of the current process
|
38 |
+
rank = int(os.environ.get('RANK'))
|
39 |
+
world_size = int(os.environ.get('WORLD_SIZE'))
|
40 |
+
local_rank = int(os.environ.get('LOCAL_RANK'))
|
41 |
+
torch.cuda.set_device(local_rank)
|
42 |
+
return rank, world_size, local_rank
|
43 |
+
|
44 |
+
def is_dist_avail_and_initialized():
|
45 |
+
if not dist.is_available():
|
46 |
+
return False
|
47 |
+
if not dist.is_initialized():
|
48 |
+
return False
|
49 |
+
return True
|
50 |
+
|
51 |
+
def get_rank():
|
52 |
+
if not is_dist_avail_and_initialized():
|
53 |
+
return 0
|
54 |
+
return dist.get_rank()
|
55 |
+
|
56 |
+
def is_main_process():
|
57 |
+
return get_rank() == 0
|
58 |
+
|
59 |
+
def get_world_size():
|
60 |
+
if not is_dist_avail_and_initialized():
|
61 |
+
return 1
|
62 |
+
return dist.get_world_size()
|
63 |
+
|
64 |
+
def save_on_master(*args, **kwargs):
|
65 |
+
if is_main_process():
|
66 |
+
torch.save(*args, **kwargs)
|
67 |
+
|
68 |
+
def save_model(args, epoch, model, optimizer, ema_params=None, epoch_name=None):
|
69 |
+
if epoch_name is None:
|
70 |
+
epoch_name = str(epoch)
|
71 |
+
|
72 |
+
output_dir = Path(args.logdir)
|
73 |
+
checkpoint_path = output_dir / ('checkpoint-%s.pth' % epoch_name)
|
74 |
+
|
75 |
+
if ema_params is not None:
|
76 |
+
ema_state_dict = copy.deepcopy(model.state_dict())
|
77 |
+
for i, (name, _value) in enumerate(model.named_parameters()):
|
78 |
+
assert name in ema_state_dict
|
79 |
+
ema_state_dict[name] = ema_params[i]
|
80 |
+
else:
|
81 |
+
ema_state_dict = None
|
82 |
+
|
83 |
+
to_save = {
|
84 |
+
'model': model.state_dict(),
|
85 |
+
'optimizer': optimizer.state_dict(),
|
86 |
+
'epoch': epoch,
|
87 |
+
'args': args,
|
88 |
+
'model_ema': ema_state_dict,
|
89 |
+
}
|
90 |
+
|
91 |
+
save_on_master(to_save, checkpoint_path)
|
92 |
+
|
93 |
+
def adjust_learning_rate(optimizer, epoch, args):
|
94 |
+
"""Decay the learning rate with half-cycle cosine after warmup"""
|
95 |
+
if epoch < args.warmup_epochs:
|
96 |
+
lr = args.lr * epoch / args.warmup_epochs
|
97 |
+
else:
|
98 |
+
lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
|
99 |
+
(1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.n_epochs - args.warmup_epochs)))
|
100 |
+
for param_group in optimizer.param_groups:
|
101 |
+
if "lr_scale" in param_group:
|
102 |
+
param_group["lr"] = lr * param_group["lr_scale"]
|
103 |
+
else:
|
104 |
+
param_group["lr"] = lr
|
105 |
+
|
106 |
+
return lr
|
107 |
+
|
108 |
+
|
109 |
+
def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
|
110 |
+
decay = []
|
111 |
+
no_decay = []
|
112 |
+
for name, param in model.named_parameters():
|
113 |
+
if not param.requires_grad:
|
114 |
+
continue # frozen weights
|
115 |
+
if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list or 'diffloss' in name:
|
116 |
+
no_decay.append(param) # no weight decay on bias, norm and diffloss
|
117 |
+
else:
|
118 |
+
decay.append(param)
|
119 |
+
return [
|
120 |
+
{'params': no_decay, 'weight_decay': 0.},
|
121 |
+
{'params': decay, 'weight_decay': weight_decay}]
|
122 |
+
|
utils/utils.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
def to_tensor(x,dtype=torch.float64):
|
5 |
+
if isinstance(x, torch.Tensor):
|
6 |
+
return x.to(dtype)
|
7 |
+
elif isinstance(x, np.ndarray):
|
8 |
+
return torch.from_numpy(x.copy()).to(dtype)
|
9 |
+
else:
|
10 |
+
raise ValueError(f"Unsupported type: {type(x)}")
|
11 |
+
|
12 |
+
def to_numpy(x):
|
13 |
+
if isinstance(x, torch.Tensor):
|
14 |
+
return x.detach().cpu().numpy()
|
15 |
+
elif isinstance(x, np.ndarray):
|
16 |
+
return x
|
17 |
+
else:
|
18 |
+
raise ValueError(f"Unsupported type: {type(x)}")
|
19 |
+
|
20 |
+
def invalid_to_nans( arr, valid_mask, ndim=999 ):
|
21 |
+
if valid_mask is not None:
|
22 |
+
arr = arr.clone()
|
23 |
+
arr[~valid_mask] = float('nan')
|
24 |
+
if arr.ndim > ndim:
|
25 |
+
arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
|
26 |
+
return arr
|
27 |
+
|
28 |
+
def invalid_to_zeros( arr, valid_mask, ndim=999 ):
|
29 |
+
if valid_mask is not None:
|
30 |
+
arr = arr.clone()
|
31 |
+
arr[~valid_mask] = 0
|
32 |
+
nnz = valid_mask.view(len(valid_mask), -1).sum(1)
|
33 |
+
else:
|
34 |
+
nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image
|
35 |
+
if arr.ndim > ndim:
|
36 |
+
arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
|
37 |
+
return arr, nnz
|
38 |
+
|
39 |
+
def scenes_to_batch(scenes,repeat=None):
|
40 |
+
batch = {}
|
41 |
+
n_cams = None
|
42 |
+
|
43 |
+
if 'new_cams' in scenes:
|
44 |
+
n_cams = scenes['new_cams']['depths'].shape[1]
|
45 |
+
batch['new_cams'], n_cams = scenes_to_batch(scenes['new_cams'])
|
46 |
+
batch['input_cams'],_ = scenes_to_batch(scenes['input_cams'],repeat=n_cams)
|
47 |
+
else:
|
48 |
+
for key in scenes.keys():
|
49 |
+
shape = scenes[key].shape
|
50 |
+
if len(shape) > 3 :
|
51 |
+
n_cams = shape[1]
|
52 |
+
if repeat is not None:
|
53 |
+
# repeat the 2nd dimension by repeat times to also have the inputs repeated in the batch
|
54 |
+
repeat_dims = (1,) * len(shape) # (1,1,1,...) for all dimensions
|
55 |
+
repeat_dims = list(repeat_dims)
|
56 |
+
repeat_dims[1] = repeat
|
57 |
+
batch[key] = scenes[key].repeat(*repeat_dims)
|
58 |
+
batch[key] = batch[key].reshape(-1, *shape[2:])
|
59 |
+
else:
|
60 |
+
batch[key] = scenes[key].reshape(-1, *shape[2:])
|
61 |
+
elif key == 'dino_features':
|
62 |
+
repeat_shape = (repeat,) + (1,) * (len(shape) - 1)
|
63 |
+
batch[key] = scenes[key].repeat(*repeat_shape)
|
64 |
+
else:
|
65 |
+
batch[key] = scenes[key]
|
66 |
+
return batch, n_cams
|
67 |
+
|
68 |
+
def dict_to_scenes(input_dict,n_cams):
|
69 |
+
scenes = {}
|
70 |
+
for key in input_dict.keys():
|
71 |
+
if isinstance(input_dict[key],dict):
|
72 |
+
scenes[key] = dict_to_scenes(input_dict[key],n_cams)
|
73 |
+
else:
|
74 |
+
scenes[key] = input_dict[key].reshape(-1, n_cams, *input_dict[key].shape[1:])
|
75 |
+
return scenes
|
76 |
+
|
77 |
+
def batch_to_scenes(pred,gt,batch,n_cams):
|
78 |
+
# pred
|
79 |
+
batch = dict_to_scenes(batch,n_cams)
|
80 |
+
pred = dict_to_scenes(pred,n_cams)
|
81 |
+
gt = dict_to_scenes(gt,n_cams)
|
82 |
+
return pred, gt, batch
|
utils/viz.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
bb = breakpoint
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from utils.utils import to_tensor, to_numpy
|
5 |
+
import open3d as o3d
|
6 |
+
import rerun as rr
|
7 |
+
|
8 |
+
OPENCV2OPENGL = (1,-1,-1,1)
|
9 |
+
|
10 |
+
def pts_to_opengl(pts):
|
11 |
+
return pts*OPENCV2OPENGL[:3]
|
12 |
+
|
13 |
+
def save_pointmaps(data,path='debug',view=False,color='novelty',frustrum_scale=20):
|
14 |
+
# debug function to save points to a ply file
|
15 |
+
import open3d as o3d
|
16 |
+
pointmaps = data['pointmaps']
|
17 |
+
B = pointmaps.shape[0]
|
18 |
+
W, H = pointmaps.shape[-3:-1]
|
19 |
+
n_cams = data['c2ws'].shape[1]
|
20 |
+
geometries = []
|
21 |
+
for b in range(B):
|
22 |
+
geometry_b = []
|
23 |
+
points = torch.cat([p.flatten(start_dim=0,end_dim=1) for p in pointmaps[b]],dim=0)
|
24 |
+
if view:
|
25 |
+
pcd = o3d.geometry.PointCloud()
|
26 |
+
pcd.points = o3d.utility.Vector3dVector(to_numpy(points))
|
27 |
+
if color == 'novelty':
|
28 |
+
colors = torch.ones_like(points)
|
29 |
+
pts_p_cam = W*H
|
30 |
+
# make all novel points red
|
31 |
+
colors[pts_p_cam:,1:]*=0.1
|
32 |
+
|
33 |
+
# make all points from first camera blue
|
34 |
+
colors[:pts_p_cam,0]*=0.1
|
35 |
+
colors[:pts_p_cam,2]*=0.1
|
36 |
+
colors*=255.0
|
37 |
+
|
38 |
+
else:
|
39 |
+
colors = torch.cat([p.flatten(start_dim=0,end_dim=1) for p in data['imgs'][b]],dim=0)
|
40 |
+
pcd.colors = o3d.utility.Vector3dVector(to_numpy(colors)/255.0)
|
41 |
+
geometry_b.append(pcd)
|
42 |
+
origin = o3d.geometry.TriangleMesh.create_coordinate_frame(
|
43 |
+
size=10, origin=[0,0,0])
|
44 |
+
geometry_b.append(origin)
|
45 |
+
for i in range(n_cams):
|
46 |
+
K = data['Ks'][b,i].cpu().numpy()
|
47 |
+
K = o3d.camera.PinholeCameraIntrinsic(W,H,K)
|
48 |
+
P = data['c2ws'][b,i].cpu().numpy()
|
49 |
+
cam_frame = o3d.geometry.LineSet.create_camera_visualization(intrinsic=K,extrinsic=P,scale=frustrum_scale)
|
50 |
+
geometry_b.append(cam_frame)
|
51 |
+
o3d.visualization.draw_geometries(geometry_b)
|
52 |
+
|
53 |
+
# add point at the origin
|
54 |
+
o3d.io.write_point_cloud(f"{path}_{b}.ply", pcd)
|
55 |
+
breakpoint()
|
56 |
+
geometries.append(geometry_b)
|
57 |
+
return geometries
|
58 |
+
|
59 |
+
def just_load_viz(pred_dict,gt_dict,batch,name='just_load_viz',addr='localhost:9000',fused_meshes=None,n_points=None):
|
60 |
+
rr.init(name)
|
61 |
+
rr.connect(addr)
|
62 |
+
rr.set_time_seconds("stable_time", 0)
|
63 |
+
|
64 |
+
context_views = batch['input_cams']['pointmaps']
|
65 |
+
context_rgbs = batch['input_cams']['imgs']
|
66 |
+
gt_pred_views = gt_dict['pointmaps']
|
67 |
+
pred_views = pred_dict['pointmaps']
|
68 |
+
|
69 |
+
# FIX this weird shape
|
70 |
+
pred_masks = batch['new_cams']['valid_masks']
|
71 |
+
context_masks = batch['input_cams']['valid_masks']
|
72 |
+
|
73 |
+
B = batch['new_cams']['pointmaps'].shape[0]
|
74 |
+
W,H = context_views.shape[-3:-1]
|
75 |
+
n_pred_cams = pred_views.shape[1]
|
76 |
+
|
77 |
+
for b in range(B):
|
78 |
+
rr.set_time_seconds("stable_time", b)
|
79 |
+
# Set world transform to identity (normal origin)
|
80 |
+
rr.log("world", rr.Transform3D(translation=[0, 0, 0], mat3x3=np.eye(3)))
|
81 |
+
## show context views
|
82 |
+
context_rgb = to_numpy(context_rgbs[b])
|
83 |
+
|
84 |
+
for i in range(n_pred_cams):
|
85 |
+
if 'conf_pointmaps' in pred_dict:
|
86 |
+
conf_pts = pred_dict['conf_pointmaps'][b,i]
|
87 |
+
|
88 |
+
#print(f"view {i} mean conf: {mean_conf}, std conf: {std_conf}")
|
89 |
+
conf_pts = (conf_pts - conf_pts.min())/(conf_pts.max() - conf_pts.min())
|
90 |
+
conf_pts = to_numpy(conf_pts)
|
91 |
+
rr.log(f"view_{i}/pred_conf", rr.Image(conf_pts))
|
92 |
+
if pred_masks[b,i].sum() == 0:
|
93 |
+
continue
|
94 |
+
if gt_pred_views is not None:
|
95 |
+
gt_pred_pts = gt_pred_views[b,i][pred_masks[b,i]]
|
96 |
+
gt_pred_pts = to_numpy(gt_pred_pts)
|
97 |
+
else:
|
98 |
+
gt_pred_pts = None
|
99 |
+
|
100 |
+
# red is color for gt points
|
101 |
+
if gt_pred_pts is not None:
|
102 |
+
color = np.array([1,0,0])
|
103 |
+
colors = np.ones_like(gt_pred_pts)
|
104 |
+
colors[:,0] = color[0]
|
105 |
+
colors[:,1] = color[1]
|
106 |
+
colors[:,2] = color[2]
|
107 |
+
rr.log(
|
108 |
+
f"world/new_views_gt/view_{i}", rr.Points3D(gt_pred_pts,colors=colors)
|
109 |
+
)
|
110 |
+
# green is color for pred points
|
111 |
+
pred_pts = pred_views[b,i][pred_masks[b,i]]
|
112 |
+
pred_pts = to_numpy(pred_pts)
|
113 |
+
|
114 |
+
depth = pred_views[b,i][:,:,2]
|
115 |
+
depth -= depth[pred_masks[b,i]].min()
|
116 |
+
depth[~pred_masks[b,i]] = 0
|
117 |
+
depth /= depth.max()
|
118 |
+
depth = to_numpy(depth)
|
119 |
+
rr.log(f"world/new_views_pred/view_{i}/image", rr.Image(depth))
|
120 |
+
|
121 |
+
if 'classifier' in pred_dict:
|
122 |
+
classifier = (pred_dict['classifier'][b,i] > 0.0).float() # this is assuming the classifier is a sigmoid output
|
123 |
+
classifier = to_numpy(classifier)
|
124 |
+
rr.log(f"view_{i}/pred_mask", rr.Image(classifier))
|
125 |
+
|
126 |
+
color = np.array([0,1,0])
|
127 |
+
colors = np.ones_like(pred_pts)
|
128 |
+
colors[:,0] = color[0]
|
129 |
+
colors[:,1] = color[1]
|
130 |
+
colors[:,2] = color[2]
|
131 |
+
if n_points is None:
|
132 |
+
rr.log(
|
133 |
+
f"world/new_views_pred/view_{i}/pred_points", rr.Points3D(pred_pts,colors=colors)
|
134 |
+
)
|
135 |
+
else:
|
136 |
+
# randomly sample n_points from pred_pts
|
137 |
+
n_points = min(n_points, pred_pts.shape[0])
|
138 |
+
inds = np.random.choice(pred_pts.shape[0], n_points, replace=False)
|
139 |
+
rr.log(
|
140 |
+
f"world/new_views_pred/view_{i}/pred_points", rr.Points3D(pred_pts[inds],colors=colors[inds])
|
141 |
+
)
|
142 |
+
|
143 |
+
K = batch['new_cams']['Ks'][b,i].cpu().numpy()
|
144 |
+
P = batch['new_cams']['c2ws'][b,i].cpu().numpy()
|
145 |
+
P = np.linalg.inv(P)
|
146 |
+
rr.log(f"world/new_views_pred/view_{i}", rr.Transform3D(translation=P[:3,3], mat3x3=P[:3,:3], from_parent=True))
|
147 |
+
|
148 |
+
rr.log(f"world/new_views_gt/view_{i}", rr.Transform3D(translation=P[:3,3], mat3x3=P[:3,:3], from_parent=True))
|
149 |
+
|
150 |
+
if 'classifier' in pred_dict:
|
151 |
+
classifier = gt_dict['valid_masks'][b,i].float()
|
152 |
+
classifier = to_numpy(classifier)
|
153 |
+
rr.log(f"view_{i}/gt_mask", rr.Image(classifier))
|
154 |
+
|
155 |
+
rr.log(
|
156 |
+
f"world/new_views_pred/view_{i}/image",
|
157 |
+
rr.Pinhole(
|
158 |
+
resolution=[H, W],
|
159 |
+
focal_length=[K[0,0], K[1,1]],
|
160 |
+
principal_point=[K[0,2], K[1,2]],
|
161 |
+
),
|
162 |
+
)
|
163 |
+
|
164 |
+
rr.log(f"world/new_views_pred/view_{i}/image", rr.Image(to_numpy(pred_masks[b,i].float())))
|
165 |
+
n_input_cams = context_masks.shape[1]
|
166 |
+
|
167 |
+
for i in range(n_input_cams):
|
168 |
+
context_pts = context_views[b][i][context_masks[b][i]]
|
169 |
+
context_pts = to_numpy(context_pts)
|
170 |
+
context_pts_rgb = context_rgbs[b][i][context_masks[b][i]]
|
171 |
+
context_pts_rgb = to_numpy(context_pts_rgb)
|
172 |
+
|
173 |
+
# depth imgs
|
174 |
+
#context_depths = batch['input_cams']['depths'][b][i]
|
175 |
+
#context_depths = (context_depths / context_depths.max() * 255.0).clamp(0,255)
|
176 |
+
#context_depths = to_numpy(context_depths).astype(np.uint8)
|
177 |
+
rr.log(
|
178 |
+
f"world/context_views/view_{i}_points", rr.Points3D(context_pts,colors=(context_pts_rgb/255.0))
|
179 |
+
)
|
180 |
+
|
181 |
+
K = batch['input_cams']['Ks'][b,i].cpu().numpy()
|
182 |
+
P = batch['input_cams']['c2ws'][b,i].cpu().numpy()
|
183 |
+
P = np.linalg.inv(P)
|
184 |
+
rr.log(f"world/context_views/view_{i}", rr.Transform3D(translation=P[:3,3], mat3x3=P[:3,:3], from_parent=True))
|
185 |
+
|
186 |
+
rr.log(
|
187 |
+
f"world/context_views/view_{i}/image",
|
188 |
+
rr.Pinhole(
|
189 |
+
resolution=[H, W],
|
190 |
+
focal_length=[K[0,0], K[1,1]],
|
191 |
+
principal_point=[K[0,2], K[1,2]],
|
192 |
+
),
|
193 |
+
)
|
194 |
+
context_rgb_i = context_rgb[i]
|
195 |
+
rr.log(
|
196 |
+
f"world/context_views/view_{i}/image", rr.Image(context_rgb_i)
|
197 |
+
)
|
198 |
+
|
199 |
+
rr.log(
|
200 |
+
f"world/context_camera_{i}/mask", rr.Image(to_numpy(context_masks[b,i].float()))
|
201 |
+
)
|
202 |
+
if fused_meshes is not None:
|
203 |
+
rr.log(f"world/fused_mesh", rr.Mesh3D(vertex_positions=fused_meshes[b]['verts'], vertex_normals=fused_meshes[b]['norms'], vertex_colors=fused_meshes[b]['colors'], triangle_indices=fused_meshes[b]['faces']))
|
204 |
+
|
205 |
+
|
xps/train_rayst3r.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import socket
|
3 |
+
import os
|
4 |
+
# Add the current working directory to the Python path
|
5 |
+
current_dir = os.getcwd()
|
6 |
+
sys.path.append(current_dir)
|
7 |
+
from xps.util import *
|
8 |
+
|
9 |
+
root_log_dir = "logs"
|
10 |
+
n_views = 2
|
11 |
+
dataset_size = -1
|
12 |
+
|
13 |
+
imshape_input = (480,640)
|
14 |
+
imshape_output = (480,640)
|
15 |
+
render_size = (480,640)
|
16 |
+
|
17 |
+
preload_train = False
|
18 |
+
data_dirs = ["/home/jovyan/shared/bduister/data/processed/","/home/jovyan/shared/bduister/data-2/processed/"]
|
19 |
+
dino_features = [4,11,17,23]
|
20 |
+
datasets = ['fp_gso','octmae']
|
21 |
+
prefetch_dino = False
|
22 |
+
normalize_mode = 'median'
|
23 |
+
#start_from = "checkpoints/gso_conf.pth"
|
24 |
+
start_from = None
|
25 |
+
|
26 |
+
noise_std = 0.005
|
27 |
+
view_select_mode = "new_zoom"
|
28 |
+
rendered_views_mode = "always"
|
29 |
+
dataset_train = f"GenericLoader(size={dataset_size},seed=747,dir={repr(data_dirs)},split='train',datasets={datasets},mode='fast',prefetch_dino={prefetch_dino}," \
|
30 |
+
+f"dino_features={dino_features},view_select_mode='{view_select_mode}',noise_std={noise_std},rendered_views_mode='{rendered_views_mode}')"
|
31 |
+
dataset_test = f"GenericLoader(size=1000,seed=787,dir={repr(data_dirs)},split='test',datasets={datasets},mode='fast',prefetch_dino={prefetch_dino}," \
|
32 |
+
+f"dino_features={dino_features},view_select_mode='{view_select_mode}',noise_std={noise_std},rendered_views_mode='{rendered_views_mode}')"
|
33 |
+
dataset_just_load = f"GenericLoader(size=1000,seed=787,dir={repr(data_dirs)},split='test',datasets={datasets},mode='fast',prefetch_dino={prefetch_dino}," \
|
34 |
+
+f"dino_features={dino_features},view_select_mode='{view_select_mode}',noise_std={noise_std},rendered_views_mode='{rendered_views_mode}')"
|
35 |
+
|
36 |
+
augmentor = "Augmentor()"
|
37 |
+
|
38 |
+
patch_size = 16
|
39 |
+
save_every = 1
|
40 |
+
|
41 |
+
vit="base"
|
42 |
+
if vit == "debug":
|
43 |
+
enc_dim = 128
|
44 |
+
dec_dim = 128
|
45 |
+
n_heads = 4
|
46 |
+
enc_depth = 4
|
47 |
+
dec_depth = 4
|
48 |
+
head_n_layers = 1
|
49 |
+
head_dim = 128
|
50 |
+
lr = 3e-4
|
51 |
+
batch_size = 20
|
52 |
+
blr = 1.5e-4
|
53 |
+
elif vit == "debug_2":
|
54 |
+
enc_dim = 512
|
55 |
+
dec_dim = 512
|
56 |
+
n_heads = 4
|
57 |
+
enc_depth = 4
|
58 |
+
dec_depth = 10
|
59 |
+
head_n_layers = 1
|
60 |
+
head_dim = 128
|
61 |
+
blr = 1.5e-4
|
62 |
+
batch_size = 18
|
63 |
+
elif vit == "small":
|
64 |
+
enc_dim = 384
|
65 |
+
dec_dim = 384
|
66 |
+
n_heads = 6
|
67 |
+
enc_depth = 12
|
68 |
+
dec_depth = 12
|
69 |
+
head_n_layers = 1
|
70 |
+
head_dim = 128
|
71 |
+
batch_size = 6
|
72 |
+
blr = 1.5e-4
|
73 |
+
elif vit == "base":
|
74 |
+
enc_dim = 768
|
75 |
+
dec_dim = 768
|
76 |
+
n_heads = 12
|
77 |
+
enc_depth = 4
|
78 |
+
dec_depth = 12
|
79 |
+
head_n_layers = 1
|
80 |
+
head_dim = 128
|
81 |
+
batch_size = 10
|
82 |
+
blr = 1.5e-4
|
83 |
+
|
84 |
+
lambda_classifier = 0.1
|
85 |
+
for skip_conf_points in [False]:
|
86 |
+
skip_conf_mask = True
|
87 |
+
model = f"RayQuery(ray_enc=RayEncoder(dim={enc_dim},num_heads={n_heads},depth={enc_depth},img_size={render_size},patch_size={patch_size})," + \
|
88 |
+
f"pointmap_enc=PointmapEncoder(dim={enc_dim},num_heads={n_heads},depth={enc_depth},img_size={render_size},patch_size={patch_size})," + \
|
89 |
+
f"dino_layers={dino_features}," + \
|
90 |
+
f"pts_head_type='dpt_depth'," + \
|
91 |
+
f"classifier_head_type='dpt_mask'," + \
|
92 |
+
f"decoder_dim={dec_dim},decoder_depth={dec_depth},decoder_num_heads={n_heads},imshape={render_size}," + \
|
93 |
+
f"criterion=DepthCompletion(ConfLoss(L21,skip_conf={skip_conf_points}),ConfLoss(ClassifierLoss(BCELoss()),skip_conf={skip_conf_mask}),lambda_classifier={lambda_classifier}),return_all_blocks=True)"
|
94 |
+
|
95 |
+
key = f"conf_points_{skip_conf_points==False}"
|
96 |
+
key = gen_key(key)
|
97 |
+
logdir = os.path.join(root_log_dir,key)
|
98 |
+
resume=logdir
|
99 |
+
wandb_run_name=key
|
100 |
+
os.makedirs(logdir,exist_ok=True)
|
101 |
+
|
102 |
+
n_epochs = 20
|
103 |
+
eval_every = 1
|
104 |
+
max_norm = -1
|
105 |
+
OMP_NUM_THREADS=16
|
106 |
+
warmup_epochs = 1
|
107 |
+
|
108 |
+
executable = f"OMP_NUM_THREADS={OMP_NUM_THREADS} torchrun --nnodes 1 --nproc_per_node $(python -c 'import torch; print(torch.cuda.device_count())') --master_port $((RANDOM%500+29000)) main.py"
|
109 |
+
#executable = f"python main.py"
|
110 |
+
if '--just_load' in sys.argv:
|
111 |
+
batch_size = 5
|
112 |
+
command = f"{executable} --{dataset_train=} --{dataset_test=} --{dataset_just_load=} --{logdir=} --{resume=} --{model=} --{batch_size=} --{normalize_mode=} --{augmentor=}"
|
113 |
+
else:
|
114 |
+
command = f"{executable} --{dataset_train=} --{dataset_test=} --{logdir=} --{n_epochs=} --{resume=} --{normalize_mode=} --{augmentor=} --{warmup_epochs=}"
|
115 |
+
command += f" --{model=} --{eval_every=} --{batch_size=} --{save_every=} --{max_norm=}"
|
116 |
+
command += f" --{blr=}"
|
117 |
+
if start_from is not None:
|
118 |
+
command += f" --{start_from=}"
|
119 |
+
if not '--no_wandb' in sys.argv:
|
120 |
+
command += f" --wandb_project=3dcomplete " + \
|
121 |
+
f"--{wandb_run_name=}"
|
122 |
+
|
123 |
+
if len(sys.argv) > 1:
|
124 |
+
for arg in sys.argv[1:]:
|
125 |
+
if not '--no_wandb' in arg:
|
126 |
+
command += f" {arg}"
|
127 |
+
print(command)
|
xps/util.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
import os
|
3 |
+
|
4 |
+
def gen_key(raw_key):
|
5 |
+
# concat the raw_key with the file name that this function is called from
|
6 |
+
current_frame = inspect.currentframe()
|
7 |
+
# Get the caller's frame (the frame that called this function)
|
8 |
+
caller_frame = current_frame.f_back
|
9 |
+
# Extract the filename from the caller's frame
|
10 |
+
caller_file = caller_frame.f_code.co_filename
|
11 |
+
caller_file = os.path.basename(caller_file).replace(".py","")
|
12 |
+
return f"{caller_file}_{raw_key}"
|