Spaces:
Runtime error
Runtime error
liuyuan-pal
commited on
Commit
·
d0f39be
1
Parent(s):
df62e57
add models
Browse files- .gitattributes +1 -0
- blender_script.py +0 -282
- ckpt/ViT-L-14.pt +3 -0
- ckpt/syncdreamer-pretrain.ckpt +3 -0
- foreground_segment.py +0 -50
- raymarching/__init__.py +0 -1
- raymarching/backend.py +0 -40
- raymarching/raymarching.py +0 -373
- raymarching/setup.py +0 -62
- raymarching/src/bindings.cpp +0 -19
- raymarching/src/raymarching.cu +0 -914
- raymarching/src/raymarching.h +0 -18
- render_batch.py +0 -20
- renderer/agg_net.py +0 -83
- renderer/cost_reg_net.py +0 -95
- renderer/dummy_dataset.py +0 -40
- renderer/feature_net.py +0 -42
- renderer/neus_networks.py +0 -503
- renderer/ngp_renderer.py +0 -721
- renderer/renderer.py +0 -604
- requirements.txt +0 -1
- train_renderer.py +0 -187
- train_syncdreamer.py +0 -307
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
ckpt/* filter=lfs diff=lfs merge=lfs -text
|
blender_script.py
DELETED
@@ -1,282 +0,0 @@
|
|
1 |
-
"""Blender script to render images of 3D models.
|
2 |
-
|
3 |
-
This script is used to render images of 3D models. It takes in a list of paths
|
4 |
-
to .glb files and renders images of each model. The images are from rotating the
|
5 |
-
object around the origin. The images are saved to the output directory.
|
6 |
-
|
7 |
-
Example usage:
|
8 |
-
blender -b -P blender_script.py -- \
|
9 |
-
--object_path my_object.glb \
|
10 |
-
--output_dir ./views \
|
11 |
-
--engine CYCLES \
|
12 |
-
--scale 0.8 \
|
13 |
-
--num_images 12 \
|
14 |
-
--camera_dist 1.2
|
15 |
-
|
16 |
-
Here, input_model_paths.json is a json file containing a list of paths to .glb.
|
17 |
-
"""
|
18 |
-
|
19 |
-
import argparse
|
20 |
-
import json
|
21 |
-
import math
|
22 |
-
import os
|
23 |
-
import random
|
24 |
-
import sys
|
25 |
-
import time
|
26 |
-
import urllib.request
|
27 |
-
from pathlib import Path
|
28 |
-
|
29 |
-
from mathutils import Vector, Matrix
|
30 |
-
import numpy as np
|
31 |
-
|
32 |
-
import bpy
|
33 |
-
from mathutils import Vector
|
34 |
-
import pickle
|
35 |
-
|
36 |
-
def read_pickle(pkl_path):
|
37 |
-
with open(pkl_path, 'rb') as f:
|
38 |
-
return pickle.load(f)
|
39 |
-
|
40 |
-
def save_pickle(data, pkl_path):
|
41 |
-
# os.system('mkdir -p {}'.format(os.path.dirname(pkl_path)))
|
42 |
-
with open(pkl_path, 'wb') as f:
|
43 |
-
pickle.dump(data, f)
|
44 |
-
|
45 |
-
parser = argparse.ArgumentParser()
|
46 |
-
parser.add_argument("--object_path", type=str, required=True)
|
47 |
-
parser.add_argument("--output_dir", type=str, required=True)
|
48 |
-
parser.add_argument("--engine", type=str, default="CYCLES", choices=["CYCLES", "BLENDER_EEVEE"])
|
49 |
-
parser.add_argument("--camera_type", type=str, default='even')
|
50 |
-
parser.add_argument("--num_images", type=int, default=16)
|
51 |
-
parser.add_argument("--elevation", type=float, default=30)
|
52 |
-
parser.add_argument("--elevation_start", type=float, default=-10)
|
53 |
-
parser.add_argument("--elevation_end", type=float, default=40)
|
54 |
-
parser.add_argument("--device", type=str, default='CUDA')
|
55 |
-
|
56 |
-
argv = sys.argv[sys.argv.index("--") + 1 :]
|
57 |
-
args = parser.parse_args(argv)
|
58 |
-
|
59 |
-
print('===================', args.engine, '===================')
|
60 |
-
|
61 |
-
context = bpy.context
|
62 |
-
scene = context.scene
|
63 |
-
render = scene.render
|
64 |
-
|
65 |
-
cam = scene.objects["Camera"]
|
66 |
-
cam.location = (0, 1.2, 0)
|
67 |
-
cam.data.lens = 35
|
68 |
-
cam.data.sensor_width = 32
|
69 |
-
|
70 |
-
cam_constraint = cam.constraints.new(type="TRACK_TO")
|
71 |
-
cam_constraint.track_axis = "TRACK_NEGATIVE_Z"
|
72 |
-
cam_constraint.up_axis = "UP_Y"
|
73 |
-
|
74 |
-
render.engine = args.engine
|
75 |
-
render.image_settings.file_format = "PNG"
|
76 |
-
render.image_settings.color_mode = "RGBA"
|
77 |
-
render.resolution_x = 256
|
78 |
-
render.resolution_y = 256
|
79 |
-
render.resolution_percentage = 100
|
80 |
-
|
81 |
-
scene.cycles.device = "GPU"
|
82 |
-
scene.cycles.samples = 128
|
83 |
-
scene.cycles.diffuse_bounces = 1
|
84 |
-
scene.cycles.glossy_bounces = 1
|
85 |
-
scene.cycles.transparent_max_bounces = 3
|
86 |
-
scene.cycles.transmission_bounces = 3
|
87 |
-
scene.cycles.filter_width = 0.01
|
88 |
-
scene.cycles.use_denoising = True
|
89 |
-
scene.render.film_transparent = True
|
90 |
-
|
91 |
-
bpy.context.preferences.addons["cycles"].preferences.get_devices()
|
92 |
-
# Set the device_type
|
93 |
-
bpy.context.preferences.addons["cycles"].preferences.compute_device_type = args.device # or "OPENCL"
|
94 |
-
bpy.context.scene.cycles.tile_size = 8192
|
95 |
-
|
96 |
-
|
97 |
-
def az_el_to_points(azimuths, elevations):
|
98 |
-
x = np.cos(azimuths)*np.cos(elevations)
|
99 |
-
y = np.sin(azimuths)*np.cos(elevations)
|
100 |
-
z = np.sin(elevations)
|
101 |
-
return np.stack([x,y,z],-1) #
|
102 |
-
|
103 |
-
def set_camera_location(cam_pt):
|
104 |
-
# from https://blender.stackexchange.com/questions/18530/
|
105 |
-
x, y, z = cam_pt # sample_spherical(radius_min=1.5, radius_max=2.2, maxz=2.2, minz=-2.2)
|
106 |
-
camera = bpy.data.objects["Camera"]
|
107 |
-
camera.location = x, y, z
|
108 |
-
|
109 |
-
return camera
|
110 |
-
|
111 |
-
def get_calibration_matrix_K_from_blender(camera):
|
112 |
-
f_in_mm = camera.data.lens
|
113 |
-
scene = bpy.context.scene
|
114 |
-
resolution_x_in_px = scene.render.resolution_x
|
115 |
-
resolution_y_in_px = scene.render.resolution_y
|
116 |
-
scale = scene.render.resolution_percentage / 100
|
117 |
-
sensor_width_in_mm = camera.data.sensor_width
|
118 |
-
sensor_height_in_mm = camera.data.sensor_height
|
119 |
-
pixel_aspect_ratio = scene.render.pixel_aspect_x / scene.render.pixel_aspect_y
|
120 |
-
|
121 |
-
if camera.data.sensor_fit == 'VERTICAL':
|
122 |
-
# the sensor height is fixed (sensor fit is horizontal),
|
123 |
-
# the sensor width is effectively changed with the pixel aspect ratio
|
124 |
-
s_u = resolution_x_in_px * scale / sensor_width_in_mm / pixel_aspect_ratio
|
125 |
-
s_v = resolution_y_in_px * scale / sensor_height_in_mm
|
126 |
-
else: # 'HORIZONTAL' and 'AUTO'
|
127 |
-
# the sensor width is fixed (sensor fit is horizontal),
|
128 |
-
# the sensor height is effectively changed with the pixel aspect ratio
|
129 |
-
s_u = resolution_x_in_px * scale / sensor_width_in_mm
|
130 |
-
s_v = resolution_y_in_px * scale * pixel_aspect_ratio / sensor_height_in_mm
|
131 |
-
|
132 |
-
# Parameters of intrinsic calibration matrix K
|
133 |
-
alpha_u = f_in_mm * s_u
|
134 |
-
alpha_v = f_in_mm * s_u
|
135 |
-
u_0 = resolution_x_in_px * scale / 2
|
136 |
-
v_0 = resolution_y_in_px * scale / 2
|
137 |
-
skew = 0 # only use rectangular pixels
|
138 |
-
|
139 |
-
K = np.asarray(((alpha_u, skew, u_0),
|
140 |
-
(0, alpha_v, v_0),
|
141 |
-
(0, 0, 1)),np.float32)
|
142 |
-
return K
|
143 |
-
|
144 |
-
|
145 |
-
def reset_scene() -> None:
|
146 |
-
"""Resets the scene to a clean state."""
|
147 |
-
# delete everything that isn't part of a camera or a light
|
148 |
-
for obj in bpy.data.objects:
|
149 |
-
if obj.type not in {"CAMERA", "LIGHT"}:
|
150 |
-
bpy.data.objects.remove(obj, do_unlink=True)
|
151 |
-
# delete all the materials
|
152 |
-
for material in bpy.data.materials:
|
153 |
-
bpy.data.materials.remove(material, do_unlink=True)
|
154 |
-
# delete all the textures
|
155 |
-
for texture in bpy.data.textures:
|
156 |
-
bpy.data.textures.remove(texture, do_unlink=True)
|
157 |
-
# delete all the images
|
158 |
-
for image in bpy.data.images:
|
159 |
-
bpy.data.images.remove(image, do_unlink=True)
|
160 |
-
|
161 |
-
|
162 |
-
# load the glb model
|
163 |
-
def load_object(object_path: str) -> None:
|
164 |
-
"""Loads a glb model into the scene."""
|
165 |
-
if object_path.endswith(".glb"):
|
166 |
-
bpy.ops.import_scene.gltf(filepath=object_path, merge_vertices=True)
|
167 |
-
elif object_path.endswith(".fbx"):
|
168 |
-
bpy.ops.import_scene.fbx(filepath=object_path)
|
169 |
-
else:
|
170 |
-
raise ValueError(f"Unsupported file type: {object_path}")
|
171 |
-
|
172 |
-
|
173 |
-
def scene_bbox(single_obj=None, ignore_matrix=False):
|
174 |
-
bbox_min = (math.inf,) * 3
|
175 |
-
bbox_max = (-math.inf,) * 3
|
176 |
-
found = False
|
177 |
-
for obj in scene_meshes() if single_obj is None else [single_obj]:
|
178 |
-
found = True
|
179 |
-
for coord in obj.bound_box:
|
180 |
-
coord = Vector(coord)
|
181 |
-
if not ignore_matrix:
|
182 |
-
coord = obj.matrix_world @ coord
|
183 |
-
bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
|
184 |
-
bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
|
185 |
-
if not found:
|
186 |
-
raise RuntimeError("no objects in scene to compute bounding box for")
|
187 |
-
return Vector(bbox_min), Vector(bbox_max)
|
188 |
-
|
189 |
-
|
190 |
-
def scene_root_objects():
|
191 |
-
for obj in bpy.context.scene.objects.values():
|
192 |
-
if not obj.parent:
|
193 |
-
yield obj
|
194 |
-
|
195 |
-
|
196 |
-
def scene_meshes():
|
197 |
-
for obj in bpy.context.scene.objects.values():
|
198 |
-
if isinstance(obj.data, (bpy.types.Mesh)):
|
199 |
-
yield obj
|
200 |
-
|
201 |
-
# function from https://github.com/panmari/stanford-shapenet-renderer/blob/master/render_blender.py
|
202 |
-
def get_3x4_RT_matrix_from_blender(cam):
|
203 |
-
bpy.context.view_layer.update()
|
204 |
-
location, rotation = cam.matrix_world.decompose()[0:2]
|
205 |
-
R = np.asarray(rotation.to_matrix())
|
206 |
-
t = np.asarray(location)
|
207 |
-
|
208 |
-
cam_rec = np.asarray([[1, 0, 0], [0, -1, 0], [0, 0, -1]], np.float32)
|
209 |
-
R = R.T
|
210 |
-
t = -R @ t
|
211 |
-
R_world2cv = cam_rec @ R
|
212 |
-
t_world2cv = cam_rec @ t
|
213 |
-
|
214 |
-
RT = np.concatenate([R_world2cv,t_world2cv[:,None]],1)
|
215 |
-
return RT
|
216 |
-
|
217 |
-
def normalize_scene():
|
218 |
-
bbox_min, bbox_max = scene_bbox()
|
219 |
-
scale = 1 / max(bbox_max - bbox_min)
|
220 |
-
for obj in scene_root_objects():
|
221 |
-
obj.scale = obj.scale * scale
|
222 |
-
# Apply scale to matrix_world.
|
223 |
-
bpy.context.view_layer.update()
|
224 |
-
bbox_min, bbox_max = scene_bbox()
|
225 |
-
offset = -(bbox_min + bbox_max) / 2
|
226 |
-
for obj in scene_root_objects():
|
227 |
-
obj.matrix_world.translation += offset
|
228 |
-
bpy.ops.object.select_all(action="DESELECT")
|
229 |
-
|
230 |
-
def save_images(object_file: str) -> None:
|
231 |
-
object_uid = os.path.basename(object_file).split(".")[0]
|
232 |
-
os.makedirs(args.output_dir, exist_ok=True)
|
233 |
-
|
234 |
-
reset_scene()
|
235 |
-
# load the object
|
236 |
-
load_object(object_file)
|
237 |
-
# object_uid = os.path.basename(object_file).split(".")[0]
|
238 |
-
normalize_scene()
|
239 |
-
|
240 |
-
# create an empty object to track
|
241 |
-
empty = bpy.data.objects.new("Empty", None)
|
242 |
-
scene.collection.objects.link(empty)
|
243 |
-
cam_constraint.target = empty
|
244 |
-
|
245 |
-
world_tree = bpy.context.scene.world.node_tree
|
246 |
-
back_node = world_tree.nodes['Background']
|
247 |
-
env_light = 0.5
|
248 |
-
back_node.inputs['Color'].default_value = Vector([env_light, env_light, env_light, 1.0])
|
249 |
-
back_node.inputs['Strength'].default_value = 1.0
|
250 |
-
|
251 |
-
distances = np.asarray([1.5 for _ in range(args.num_images)])
|
252 |
-
if args.camera_type=='fixed':
|
253 |
-
azimuths = (np.arange(args.num_images)/args.num_images*np.pi*2).astype(np.float32)
|
254 |
-
elevations = np.deg2rad(np.asarray([args.elevation] * args.num_images).astype(np.float32))
|
255 |
-
elif args.camera_type=='random':
|
256 |
-
azimuths = (np.arange(args.num_images) / args.num_images * np.pi * 2).astype(np.float32)
|
257 |
-
elevations = np.random.uniform(args.elevation_start, args.elevation_end, args.num_images)
|
258 |
-
elevations = np.deg2rad(elevations)
|
259 |
-
else:
|
260 |
-
raise NotImplementedError
|
261 |
-
|
262 |
-
cam_pts = az_el_to_points(azimuths, elevations) * distances[:,None]
|
263 |
-
cam_poses = []
|
264 |
-
(Path(args.output_dir) / object_uid).mkdir(exist_ok=True, parents=True)
|
265 |
-
for i in range(args.num_images):
|
266 |
-
# set camera
|
267 |
-
camera = set_camera_location(cam_pts[i])
|
268 |
-
RT = get_3x4_RT_matrix_from_blender(camera)
|
269 |
-
cam_poses.append(RT)
|
270 |
-
|
271 |
-
render_path = os.path.join(args.output_dir, object_uid, f"{i:03d}.png")
|
272 |
-
if os.path.exists(render_path): continue
|
273 |
-
scene.render.filepath = os.path.abspath(render_path)
|
274 |
-
bpy.ops.render.render(write_still=True)
|
275 |
-
|
276 |
-
if args.camera_type=='random':
|
277 |
-
K = get_calibration_matrix_K_from_blender(camera)
|
278 |
-
cam_poses = np.stack(cam_poses, 0)
|
279 |
-
save_pickle([K, azimuths, elevations, distances, cam_poses], os.path.join(args.output_dir, object_uid, "meta.pkl"))
|
280 |
-
|
281 |
-
if __name__ == "__main__":
|
282 |
-
save_images(args.object_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ckpt/ViT-L-14.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836
|
3 |
+
size 932768134
|
ckpt/syncdreamer-pretrain.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5ebb31334d9e4002b2590dd805e25238beaf95fa082f6e39a132344624448dcb
|
3 |
+
size 5570034171
|
foreground_segment.py
DELETED
@@ -1,50 +0,0 @@
|
|
1 |
-
import cv2
|
2 |
-
import argparse
|
3 |
-
import numpy as np
|
4 |
-
|
5 |
-
import torch
|
6 |
-
from PIL import Image
|
7 |
-
|
8 |
-
|
9 |
-
class BackgroundRemoval:
|
10 |
-
def __init__(self, device='cuda'):
|
11 |
-
from carvekit.api.high import HiInterface
|
12 |
-
self.interface = HiInterface(
|
13 |
-
object_type="object", # Can be "object" or "hairs-like".
|
14 |
-
batch_size_seg=5,
|
15 |
-
batch_size_matting=1,
|
16 |
-
device=device,
|
17 |
-
seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
|
18 |
-
matting_mask_size=2048,
|
19 |
-
trimap_prob_threshold=231,
|
20 |
-
trimap_dilation=30,
|
21 |
-
trimap_erosion_iters=5,
|
22 |
-
fp16=True,
|
23 |
-
)
|
24 |
-
|
25 |
-
@torch.no_grad()
|
26 |
-
def __call__(self, image):
|
27 |
-
# image: [H, W, 3] array in [0, 255].
|
28 |
-
image = Image.fromarray(image)
|
29 |
-
image = self.interface([image])[0]
|
30 |
-
image = np.array(image)
|
31 |
-
return image
|
32 |
-
|
33 |
-
def process(image_path, mask_path):
|
34 |
-
mask_predictor = BackgroundRemoval()
|
35 |
-
image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
|
36 |
-
if image.shape[-1] == 4:
|
37 |
-
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
|
38 |
-
else:
|
39 |
-
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
40 |
-
rgba = mask_predictor(image) # [H, W, 4]
|
41 |
-
cv2.imwrite(mask_path, cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA))
|
42 |
-
|
43 |
-
|
44 |
-
if __name__ == '__main__':
|
45 |
-
parser = argparse.ArgumentParser()
|
46 |
-
parser.add_argument('--input', required=True, type=str)
|
47 |
-
parser.add_argument('--output', required=True, type=str)
|
48 |
-
opt = parser.parse_args()
|
49 |
-
|
50 |
-
process(opt.input, opt.output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raymarching/__init__.py
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
from .raymarching import *
|
|
|
|
raymarching/backend.py
DELETED
@@ -1,40 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from torch.utils.cpp_extension import load
|
3 |
-
|
4 |
-
_src_path = os.path.dirname(os.path.abspath(__file__))
|
5 |
-
|
6 |
-
nvcc_flags = [
|
7 |
-
'-O3', '-std=c++14',
|
8 |
-
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
9 |
-
]
|
10 |
-
|
11 |
-
if os.name == "posix":
|
12 |
-
c_flags = ['-O3', '-std=c++14']
|
13 |
-
elif os.name == "nt":
|
14 |
-
c_flags = ['/O2', '/std:c++17']
|
15 |
-
|
16 |
-
# find cl.exe
|
17 |
-
def find_cl_path():
|
18 |
-
import glob
|
19 |
-
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
|
20 |
-
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
|
21 |
-
if paths:
|
22 |
-
return paths[0]
|
23 |
-
|
24 |
-
# If cl.exe is not on path, try to find it.
|
25 |
-
if os.system("where cl.exe >nul 2>nul") != 0:
|
26 |
-
cl_path = find_cl_path()
|
27 |
-
if cl_path is None:
|
28 |
-
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
|
29 |
-
os.environ["PATH"] += ";" + cl_path
|
30 |
-
|
31 |
-
_backend = load(name='_raymarching',
|
32 |
-
extra_cflags=c_flags,
|
33 |
-
extra_cuda_cflags=nvcc_flags,
|
34 |
-
sources=[os.path.join(_src_path, 'src', f) for f in [
|
35 |
-
'raymarching.cu',
|
36 |
-
'bindings.cpp',
|
37 |
-
]],
|
38 |
-
)
|
39 |
-
|
40 |
-
__all__ = ['_backend']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raymarching/raymarching.py
DELETED
@@ -1,373 +0,0 @@
|
|
1 |
-
import numpy as np
|
2 |
-
import time
|
3 |
-
|
4 |
-
import torch
|
5 |
-
import torch.nn as nn
|
6 |
-
from torch.autograd import Function
|
7 |
-
from torch.cuda.amp import custom_bwd, custom_fwd
|
8 |
-
|
9 |
-
try:
|
10 |
-
import _raymarching as _backend
|
11 |
-
except ImportError:
|
12 |
-
from .backend import _backend
|
13 |
-
|
14 |
-
|
15 |
-
# ----------------------------------------
|
16 |
-
# utils
|
17 |
-
# ----------------------------------------
|
18 |
-
|
19 |
-
class _near_far_from_aabb(Function):
|
20 |
-
@staticmethod
|
21 |
-
@custom_fwd(cast_inputs=torch.float32)
|
22 |
-
def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):
|
23 |
-
''' near_far_from_aabb, CUDA implementation
|
24 |
-
Calculate rays' intersection time (near and far) with aabb
|
25 |
-
Args:
|
26 |
-
rays_o: float, [N, 3]
|
27 |
-
rays_d: float, [N, 3]
|
28 |
-
aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax)
|
29 |
-
min_near: float, scalar
|
30 |
-
Returns:
|
31 |
-
nears: float, [N]
|
32 |
-
fars: float, [N]
|
33 |
-
'''
|
34 |
-
if not rays_o.is_cuda: rays_o = rays_o.cuda()
|
35 |
-
if not rays_d.is_cuda: rays_d = rays_d.cuda()
|
36 |
-
|
37 |
-
rays_o = rays_o.contiguous().view(-1, 3)
|
38 |
-
rays_d = rays_d.contiguous().view(-1, 3)
|
39 |
-
|
40 |
-
N = rays_o.shape[0] # num rays
|
41 |
-
|
42 |
-
nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
|
43 |
-
fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
|
44 |
-
|
45 |
-
_backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars)
|
46 |
-
|
47 |
-
return nears, fars
|
48 |
-
|
49 |
-
near_far_from_aabb = _near_far_from_aabb.apply
|
50 |
-
|
51 |
-
|
52 |
-
class _sph_from_ray(Function):
|
53 |
-
@staticmethod
|
54 |
-
@custom_fwd(cast_inputs=torch.float32)
|
55 |
-
def forward(ctx, rays_o, rays_d, radius):
|
56 |
-
''' sph_from_ray, CUDA implementation
|
57 |
-
get spherical coordinate on the background sphere from rays.
|
58 |
-
Assume rays_o are inside the Sphere(radius).
|
59 |
-
Args:
|
60 |
-
rays_o: [N, 3]
|
61 |
-
rays_d: [N, 3]
|
62 |
-
radius: scalar, float
|
63 |
-
Return:
|
64 |
-
coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface)
|
65 |
-
'''
|
66 |
-
if not rays_o.is_cuda: rays_o = rays_o.cuda()
|
67 |
-
if not rays_d.is_cuda: rays_d = rays_d.cuda()
|
68 |
-
|
69 |
-
rays_o = rays_o.contiguous().view(-1, 3)
|
70 |
-
rays_d = rays_d.contiguous().view(-1, 3)
|
71 |
-
|
72 |
-
N = rays_o.shape[0] # num rays
|
73 |
-
|
74 |
-
coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device)
|
75 |
-
|
76 |
-
_backend.sph_from_ray(rays_o, rays_d, radius, N, coords)
|
77 |
-
|
78 |
-
return coords
|
79 |
-
|
80 |
-
sph_from_ray = _sph_from_ray.apply
|
81 |
-
|
82 |
-
|
83 |
-
class _morton3D(Function):
|
84 |
-
@staticmethod
|
85 |
-
def forward(ctx, coords):
|
86 |
-
''' morton3D, CUDA implementation
|
87 |
-
Args:
|
88 |
-
coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...)
|
89 |
-
TODO: check if the coord range is valid! (current 128 is safe)
|
90 |
-
Returns:
|
91 |
-
indices: [N], int32, in [0, 128^3)
|
92 |
-
|
93 |
-
'''
|
94 |
-
if not coords.is_cuda: coords = coords.cuda()
|
95 |
-
|
96 |
-
N = coords.shape[0]
|
97 |
-
|
98 |
-
indices = torch.empty(N, dtype=torch.int32, device=coords.device)
|
99 |
-
|
100 |
-
_backend.morton3D(coords.int(), N, indices)
|
101 |
-
|
102 |
-
return indices
|
103 |
-
|
104 |
-
morton3D = _morton3D.apply
|
105 |
-
|
106 |
-
class _morton3D_invert(Function):
|
107 |
-
@staticmethod
|
108 |
-
def forward(ctx, indices):
|
109 |
-
''' morton3D_invert, CUDA implementation
|
110 |
-
Args:
|
111 |
-
indices: [N], int32, in [0, 128^3)
|
112 |
-
Returns:
|
113 |
-
coords: [N, 3], int32, in [0, 128)
|
114 |
-
|
115 |
-
'''
|
116 |
-
if not indices.is_cuda: indices = indices.cuda()
|
117 |
-
|
118 |
-
N = indices.shape[0]
|
119 |
-
|
120 |
-
coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device)
|
121 |
-
|
122 |
-
_backend.morton3D_invert(indices.int(), N, coords)
|
123 |
-
|
124 |
-
return coords
|
125 |
-
|
126 |
-
morton3D_invert = _morton3D_invert.apply
|
127 |
-
|
128 |
-
|
129 |
-
class _packbits(Function):
|
130 |
-
@staticmethod
|
131 |
-
@custom_fwd(cast_inputs=torch.float32)
|
132 |
-
def forward(ctx, grid, thresh, bitfield=None):
|
133 |
-
''' packbits, CUDA implementation
|
134 |
-
Pack up the density grid into a bit field to accelerate ray marching.
|
135 |
-
Args:
|
136 |
-
grid: float, [C, H * H * H], assume H % 2 == 0
|
137 |
-
thresh: float, threshold
|
138 |
-
Returns:
|
139 |
-
bitfield: uint8, [C, H * H * H / 8]
|
140 |
-
'''
|
141 |
-
if not grid.is_cuda: grid = grid.cuda()
|
142 |
-
grid = grid.contiguous()
|
143 |
-
|
144 |
-
C = grid.shape[0]
|
145 |
-
H3 = grid.shape[1]
|
146 |
-
N = C * H3 // 8
|
147 |
-
|
148 |
-
if bitfield is None:
|
149 |
-
bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device)
|
150 |
-
|
151 |
-
_backend.packbits(grid, N, thresh, bitfield)
|
152 |
-
|
153 |
-
return bitfield
|
154 |
-
|
155 |
-
packbits = _packbits.apply
|
156 |
-
|
157 |
-
# ----------------------------------------
|
158 |
-
# train functions
|
159 |
-
# ----------------------------------------
|
160 |
-
|
161 |
-
class _march_rays_train(Function):
|
162 |
-
@staticmethod
|
163 |
-
@custom_fwd(cast_inputs=torch.float32)
|
164 |
-
def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024):
|
165 |
-
''' march rays to generate points (forward only)
|
166 |
-
Args:
|
167 |
-
rays_o/d: float, [N, 3]
|
168 |
-
bound: float, scalar
|
169 |
-
density_bitfield: uint8: [CHHH // 8]
|
170 |
-
C: int
|
171 |
-
H: int
|
172 |
-
nears/fars: float, [N]
|
173 |
-
step_counter: int32, (2), used to count the actual number of generated points.
|
174 |
-
mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeded this threshold.)
|
175 |
-
perturb: bool
|
176 |
-
align: int, pad output so its size is dividable by align, set to -1 to disable.
|
177 |
-
force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays.
|
178 |
-
dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)
|
179 |
-
max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
|
180 |
-
Returns:
|
181 |
-
xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray)
|
182 |
-
dirs: float, [M, 3], all generated points' view dirs.
|
183 |
-
deltas: float, [M, 2], all generated points' deltas. (first for RGB, second for Depth)
|
184 |
-
rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 2]] --> points belonging to rays[i, 0]
|
185 |
-
'''
|
186 |
-
|
187 |
-
if not rays_o.is_cuda: rays_o = rays_o.cuda()
|
188 |
-
if not rays_d.is_cuda: rays_d = rays_d.cuda()
|
189 |
-
if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda()
|
190 |
-
|
191 |
-
rays_o = rays_o.contiguous().view(-1, 3)
|
192 |
-
rays_d = rays_d.contiguous().view(-1, 3)
|
193 |
-
density_bitfield = density_bitfield.contiguous()
|
194 |
-
|
195 |
-
N = rays_o.shape[0] # num rays
|
196 |
-
M = N * max_steps # init max points number in total
|
197 |
-
|
198 |
-
# running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp)
|
199 |
-
# It estimate the max points number to enable faster training, but will lead to random ignored rays if underestimated.
|
200 |
-
if not force_all_rays and mean_count > 0:
|
201 |
-
if align > 0:
|
202 |
-
mean_count += align - mean_count % align
|
203 |
-
M = mean_count
|
204 |
-
|
205 |
-
xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
|
206 |
-
dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
|
207 |
-
deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device)
|
208 |
-
rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps
|
209 |
-
|
210 |
-
if step_counter is None:
|
211 |
-
step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter
|
212 |
-
|
213 |
-
if perturb:
|
214 |
-
noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device)
|
215 |
-
else:
|
216 |
-
noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device)
|
217 |
-
|
218 |
-
_backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, noises) # m is the actually used points number
|
219 |
-
|
220 |
-
#print(step_counter, M)
|
221 |
-
|
222 |
-
# only used at the first (few) epochs.
|
223 |
-
if force_all_rays or mean_count <= 0:
|
224 |
-
m = step_counter[0].item() # D2H copy
|
225 |
-
if align > 0:
|
226 |
-
m += align - m % align
|
227 |
-
xyzs = xyzs[:m]
|
228 |
-
dirs = dirs[:m]
|
229 |
-
deltas = deltas[:m]
|
230 |
-
|
231 |
-
torch.cuda.empty_cache()
|
232 |
-
|
233 |
-
return xyzs, dirs, deltas, rays
|
234 |
-
|
235 |
-
march_rays_train = _march_rays_train.apply
|
236 |
-
|
237 |
-
|
238 |
-
class _composite_rays_train(Function):
|
239 |
-
@staticmethod
|
240 |
-
@custom_fwd(cast_inputs=torch.float32)
|
241 |
-
def forward(ctx, sigmas, rgbs, deltas, rays, T_thresh=1e-4):
|
242 |
-
''' composite rays' rgbs, according to the ray marching formula.
|
243 |
-
Args:
|
244 |
-
rgbs: float, [M, 3]
|
245 |
-
sigmas: float, [M,]
|
246 |
-
deltas: float, [M, 2]
|
247 |
-
rays: int32, [N, 3]
|
248 |
-
Returns:
|
249 |
-
weights_sum: float, [N,], the alpha channel
|
250 |
-
depth: float, [N, ], the Depth
|
251 |
-
image: float, [N, 3], the RGB channel (after multiplying alpha!)
|
252 |
-
'''
|
253 |
-
|
254 |
-
sigmas = sigmas.contiguous()
|
255 |
-
rgbs = rgbs.contiguous()
|
256 |
-
|
257 |
-
M = sigmas.shape[0]
|
258 |
-
N = rays.shape[0]
|
259 |
-
|
260 |
-
weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
|
261 |
-
depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
|
262 |
-
image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
|
263 |
-
|
264 |
-
_backend.composite_rays_train_forward(sigmas, rgbs, deltas, rays, M, N, T_thresh, weights_sum, depth, image)
|
265 |
-
|
266 |
-
ctx.save_for_backward(sigmas, rgbs, deltas, rays, weights_sum, depth, image)
|
267 |
-
ctx.dims = [M, N, T_thresh]
|
268 |
-
|
269 |
-
return weights_sum, depth, image
|
270 |
-
|
271 |
-
@staticmethod
|
272 |
-
@custom_bwd
|
273 |
-
def backward(ctx, grad_weights_sum, grad_depth, grad_image):
|
274 |
-
|
275 |
-
# NOTE: grad_depth is not used now! It won't be propagated to sigmas.
|
276 |
-
|
277 |
-
grad_weights_sum = grad_weights_sum.contiguous()
|
278 |
-
grad_image = grad_image.contiguous()
|
279 |
-
|
280 |
-
sigmas, rgbs, deltas, rays, weights_sum, depth, image = ctx.saved_tensors
|
281 |
-
M, N, T_thresh = ctx.dims
|
282 |
-
|
283 |
-
grad_sigmas = torch.zeros_like(sigmas)
|
284 |
-
grad_rgbs = torch.zeros_like(rgbs)
|
285 |
-
|
286 |
-
_backend.composite_rays_train_backward(grad_weights_sum, grad_image, sigmas, rgbs, deltas, rays, weights_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs)
|
287 |
-
|
288 |
-
return grad_sigmas, grad_rgbs, None, None, None
|
289 |
-
|
290 |
-
|
291 |
-
composite_rays_train = _composite_rays_train.apply
|
292 |
-
|
293 |
-
# ----------------------------------------
|
294 |
-
# infer functions
|
295 |
-
# ----------------------------------------
|
296 |
-
|
297 |
-
class _march_rays(Function):
|
298 |
-
@staticmethod
|
299 |
-
@custom_fwd(cast_inputs=torch.float32)
|
300 |
-
def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024):
|
301 |
-
''' march rays to generate points (forward only, for inference)
|
302 |
-
Args:
|
303 |
-
n_alive: int, number of alive rays
|
304 |
-
n_step: int, how many steps we march
|
305 |
-
rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive)
|
306 |
-
rays_t: float, [N], the alive rays' time, we only use the first n_alive.
|
307 |
-
rays_o/d: float, [N, 3]
|
308 |
-
bound: float, scalar
|
309 |
-
density_bitfield: uint8: [CHHH // 8]
|
310 |
-
C: int
|
311 |
-
H: int
|
312 |
-
nears/fars: float, [N]
|
313 |
-
align: int, pad output so its size is dividable by align, set to -1 to disable.
|
314 |
-
perturb: bool/int, int > 0 is used as the random seed.
|
315 |
-
dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)
|
316 |
-
max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
|
317 |
-
Returns:
|
318 |
-
xyzs: float, [n_alive * n_step, 3], all generated points' coords
|
319 |
-
dirs: float, [n_alive * n_step, 3], all generated points' view dirs.
|
320 |
-
deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
|
321 |
-
'''
|
322 |
-
|
323 |
-
if not rays_o.is_cuda: rays_o = rays_o.cuda()
|
324 |
-
if not rays_d.is_cuda: rays_d = rays_d.cuda()
|
325 |
-
|
326 |
-
rays_o = rays_o.contiguous().view(-1, 3)
|
327 |
-
rays_d = rays_d.contiguous().view(-1, 3)
|
328 |
-
|
329 |
-
M = n_alive * n_step
|
330 |
-
|
331 |
-
if align > 0:
|
332 |
-
M += align - (M % align)
|
333 |
-
|
334 |
-
xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
|
335 |
-
dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
|
336 |
-
deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth
|
337 |
-
|
338 |
-
if perturb:
|
339 |
-
# torch.manual_seed(perturb) # test_gui uses spp index as seed
|
340 |
-
noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device)
|
341 |
-
else:
|
342 |
-
noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device)
|
343 |
-
|
344 |
-
_backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, noises)
|
345 |
-
|
346 |
-
return xyzs, dirs, deltas
|
347 |
-
|
348 |
-
march_rays = _march_rays.apply
|
349 |
-
|
350 |
-
|
351 |
-
class _composite_rays(Function):
|
352 |
-
@staticmethod
|
353 |
-
@custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
|
354 |
-
def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh=1e-2):
|
355 |
-
''' composite rays' rgbs, according to the ray marching formula. (for inference)
|
356 |
-
Args:
|
357 |
-
n_alive: int, number of alive rays
|
358 |
-
n_step: int, how many steps we march
|
359 |
-
rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive)
|
360 |
-
rays_t: float, [N], the alive rays' time
|
361 |
-
sigmas: float, [n_alive * n_step,]
|
362 |
-
rgbs: float, [n_alive * n_step, 3]
|
363 |
-
deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
|
364 |
-
In-place Outputs:
|
365 |
-
weights_sum: float, [N,], the alpha channel
|
366 |
-
depth: float, [N,], the depth value
|
367 |
-
image: float, [N, 3], the RGB channel (after multiplying alpha!)
|
368 |
-
'''
|
369 |
-
_backend.composite_rays(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image)
|
370 |
-
return tuple()
|
371 |
-
|
372 |
-
|
373 |
-
composite_rays = _composite_rays.apply
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raymarching/setup.py
DELETED
@@ -1,62 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from setuptools import setup
|
3 |
-
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
4 |
-
|
5 |
-
_src_path = os.path.dirname(os.path.abspath(__file__))
|
6 |
-
|
7 |
-
nvcc_flags = [
|
8 |
-
'-O3', '-std=c++14',
|
9 |
-
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
10 |
-
]
|
11 |
-
|
12 |
-
if os.name == "posix":
|
13 |
-
c_flags = ['-O3', '-std=c++14']
|
14 |
-
elif os.name == "nt":
|
15 |
-
c_flags = ['/O2', '/std:c++17']
|
16 |
-
|
17 |
-
# find cl.exe
|
18 |
-
def find_cl_path():
|
19 |
-
import glob
|
20 |
-
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
|
21 |
-
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
|
22 |
-
if paths:
|
23 |
-
return paths[0]
|
24 |
-
|
25 |
-
# If cl.exe is not on path, try to find it.
|
26 |
-
if os.system("where cl.exe >nul 2>nul") != 0:
|
27 |
-
cl_path = find_cl_path()
|
28 |
-
if cl_path is None:
|
29 |
-
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
|
30 |
-
os.environ["PATH"] += ";" + cl_path
|
31 |
-
|
32 |
-
'''
|
33 |
-
Usage:
|
34 |
-
|
35 |
-
python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory)
|
36 |
-
|
37 |
-
python setup.py install # build extensions and install (copy) to PATH.
|
38 |
-
pip install . # ditto but better (e.g., dependency & metadata handling)
|
39 |
-
|
40 |
-
python setup.py develop # build extensions and install (symbolic) to PATH.
|
41 |
-
pip install -e . # ditto but better (e.g., dependency & metadata handling)
|
42 |
-
|
43 |
-
'''
|
44 |
-
setup(
|
45 |
-
name='raymarching', # package name, import this to use python API
|
46 |
-
ext_modules=[
|
47 |
-
CUDAExtension(
|
48 |
-
name='_raymarching', # extension name, import this to use CUDA API
|
49 |
-
sources=[os.path.join(_src_path, 'src', f) for f in [
|
50 |
-
'raymarching.cu',
|
51 |
-
'bindings.cpp',
|
52 |
-
]],
|
53 |
-
extra_compile_args={
|
54 |
-
'cxx': c_flags,
|
55 |
-
'nvcc': nvcc_flags,
|
56 |
-
}
|
57 |
-
),
|
58 |
-
],
|
59 |
-
cmdclass={
|
60 |
-
'build_ext': BuildExtension,
|
61 |
-
}
|
62 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raymarching/src/bindings.cpp
DELETED
@@ -1,19 +0,0 @@
|
|
1 |
-
#include <torch/extension.h>
|
2 |
-
|
3 |
-
#include "raymarching.h"
|
4 |
-
|
5 |
-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
6 |
-
// utils
|
7 |
-
m.def("packbits", &packbits, "packbits (CUDA)");
|
8 |
-
m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)");
|
9 |
-
m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)");
|
10 |
-
m.def("morton3D", &morton3D, "morton3D (CUDA)");
|
11 |
-
m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)");
|
12 |
-
// train
|
13 |
-
m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)");
|
14 |
-
m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)");
|
15 |
-
m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)");
|
16 |
-
// infer
|
17 |
-
m.def("march_rays", &march_rays, "march rays (CUDA)");
|
18 |
-
m.def("composite_rays", &composite_rays, "composite rays (CUDA)");
|
19 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raymarching/src/raymarching.cu
DELETED
@@ -1,914 +0,0 @@
|
|
1 |
-
#include <cuda.h>
|
2 |
-
#include <cuda_fp16.h>
|
3 |
-
#include <cuda_runtime.h>
|
4 |
-
|
5 |
-
#include <ATen/cuda/CUDAContext.h>
|
6 |
-
#include <torch/torch.h>
|
7 |
-
|
8 |
-
#include <cstdio>
|
9 |
-
#include <stdint.h>
|
10 |
-
#include <stdexcept>
|
11 |
-
#include <limits>
|
12 |
-
|
13 |
-
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
14 |
-
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
|
15 |
-
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
|
16 |
-
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
|
17 |
-
|
18 |
-
|
19 |
-
inline constexpr __device__ float SQRT3() { return 1.7320508075688772f; }
|
20 |
-
inline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; }
|
21 |
-
inline constexpr __device__ float PI() { return 3.141592653589793f; }
|
22 |
-
inline constexpr __device__ float RPI() { return 0.3183098861837907f; }
|
23 |
-
|
24 |
-
|
25 |
-
template <typename T>
|
26 |
-
inline __host__ __device__ T div_round_up(T val, T divisor) {
|
27 |
-
return (val + divisor - 1) / divisor;
|
28 |
-
}
|
29 |
-
|
30 |
-
inline __host__ __device__ float signf(const float x) {
|
31 |
-
return copysignf(1.0, x);
|
32 |
-
}
|
33 |
-
|
34 |
-
inline __host__ __device__ float clamp(const float x, const float min, const float max) {
|
35 |
-
return fminf(max, fmaxf(min, x));
|
36 |
-
}
|
37 |
-
|
38 |
-
inline __host__ __device__ void swapf(float& a, float& b) {
|
39 |
-
float c = a; a = b; b = c;
|
40 |
-
}
|
41 |
-
|
42 |
-
inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) {
|
43 |
-
const float mx = fmaxf(fabsf(x), fmaxf(fabs(y), fabs(z)));
|
44 |
-
int exponent;
|
45 |
-
frexpf(mx, &exponent); // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ...
|
46 |
-
return fminf(max_cascade - 1, fmaxf(0, exponent));
|
47 |
-
}
|
48 |
-
|
49 |
-
inline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) {
|
50 |
-
const float mx = dt * H * 0.5;
|
51 |
-
int exponent;
|
52 |
-
frexpf(mx, &exponent);
|
53 |
-
return fminf(max_cascade - 1, fmaxf(0, exponent));
|
54 |
-
}
|
55 |
-
|
56 |
-
inline __host__ __device__ uint32_t __expand_bits(uint32_t v)
|
57 |
-
{
|
58 |
-
v = (v * 0x00010001u) & 0xFF0000FFu;
|
59 |
-
v = (v * 0x00000101u) & 0x0F00F00Fu;
|
60 |
-
v = (v * 0x00000011u) & 0xC30C30C3u;
|
61 |
-
v = (v * 0x00000005u) & 0x49249249u;
|
62 |
-
return v;
|
63 |
-
}
|
64 |
-
|
65 |
-
inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z)
|
66 |
-
{
|
67 |
-
uint32_t xx = __expand_bits(x);
|
68 |
-
uint32_t yy = __expand_bits(y);
|
69 |
-
uint32_t zz = __expand_bits(z);
|
70 |
-
return xx | (yy << 1) | (zz << 2);
|
71 |
-
}
|
72 |
-
|
73 |
-
inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x)
|
74 |
-
{
|
75 |
-
x = x & 0x49249249;
|
76 |
-
x = (x | (x >> 2)) & 0xc30c30c3;
|
77 |
-
x = (x | (x >> 4)) & 0x0f00f00f;
|
78 |
-
x = (x | (x >> 8)) & 0xff0000ff;
|
79 |
-
x = (x | (x >> 16)) & 0x0000ffff;
|
80 |
-
return x;
|
81 |
-
}
|
82 |
-
|
83 |
-
|
84 |
-
////////////////////////////////////////////////////
|
85 |
-
///////////// utils /////////////
|
86 |
-
////////////////////////////////////////////////////
|
87 |
-
|
88 |
-
// rays_o/d: [N, 3]
|
89 |
-
// nears/fars: [N]
|
90 |
-
// scalar_t should always be float in use.
|
91 |
-
template <typename scalar_t>
|
92 |
-
__global__ void kernel_near_far_from_aabb(
|
93 |
-
const scalar_t * __restrict__ rays_o,
|
94 |
-
const scalar_t * __restrict__ rays_d,
|
95 |
-
const scalar_t * __restrict__ aabb,
|
96 |
-
const uint32_t N,
|
97 |
-
const float min_near,
|
98 |
-
scalar_t * nears, scalar_t * fars
|
99 |
-
) {
|
100 |
-
// parallel per ray
|
101 |
-
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
102 |
-
if (n >= N) return;
|
103 |
-
|
104 |
-
// locate
|
105 |
-
rays_o += n * 3;
|
106 |
-
rays_d += n * 3;
|
107 |
-
|
108 |
-
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
|
109 |
-
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
|
110 |
-
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
|
111 |
-
|
112 |
-
// get near far (assume cube scene)
|
113 |
-
float near = (aabb[0] - ox) * rdx;
|
114 |
-
float far = (aabb[3] - ox) * rdx;
|
115 |
-
if (near > far) swapf(near, far);
|
116 |
-
|
117 |
-
float near_y = (aabb[1] - oy) * rdy;
|
118 |
-
float far_y = (aabb[4] - oy) * rdy;
|
119 |
-
if (near_y > far_y) swapf(near_y, far_y);
|
120 |
-
|
121 |
-
if (near > far_y || near_y > far) {
|
122 |
-
nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
|
123 |
-
return;
|
124 |
-
}
|
125 |
-
|
126 |
-
if (near_y > near) near = near_y;
|
127 |
-
if (far_y < far) far = far_y;
|
128 |
-
|
129 |
-
float near_z = (aabb[2] - oz) * rdz;
|
130 |
-
float far_z = (aabb[5] - oz) * rdz;
|
131 |
-
if (near_z > far_z) swapf(near_z, far_z);
|
132 |
-
|
133 |
-
if (near > far_z || near_z > far) {
|
134 |
-
nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
|
135 |
-
return;
|
136 |
-
}
|
137 |
-
|
138 |
-
if (near_z > near) near = near_z;
|
139 |
-
if (far_z < far) far = far_z;
|
140 |
-
|
141 |
-
if (near < min_near) near = min_near;
|
142 |
-
|
143 |
-
nears[n] = near;
|
144 |
-
fars[n] = far;
|
145 |
-
}
|
146 |
-
|
147 |
-
|
148 |
-
void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) {
|
149 |
-
|
150 |
-
static constexpr uint32_t N_THREAD = 128;
|
151 |
-
|
152 |
-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
153 |
-
rays_o.scalar_type(), "near_far_from_aabb", ([&] {
|
154 |
-
kernel_near_far_from_aabb<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), aabb.data_ptr<scalar_t>(), N, min_near, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>());
|
155 |
-
}));
|
156 |
-
}
|
157 |
-
|
158 |
-
|
159 |
-
// rays_o/d: [N, 3]
|
160 |
-
// radius: float
|
161 |
-
// coords: [N, 2]
|
162 |
-
template <typename scalar_t>
|
163 |
-
__global__ void kernel_sph_from_ray(
|
164 |
-
const scalar_t * __restrict__ rays_o,
|
165 |
-
const scalar_t * __restrict__ rays_d,
|
166 |
-
const float radius,
|
167 |
-
const uint32_t N,
|
168 |
-
scalar_t * coords
|
169 |
-
) {
|
170 |
-
// parallel per ray
|
171 |
-
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
172 |
-
if (n >= N) return;
|
173 |
-
|
174 |
-
// locate
|
175 |
-
rays_o += n * 3;
|
176 |
-
rays_d += n * 3;
|
177 |
-
coords += n * 2;
|
178 |
-
|
179 |
-
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
|
180 |
-
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
|
181 |
-
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
|
182 |
-
|
183 |
-
// solve t from || o + td || = radius
|
184 |
-
const float A = dx * dx + dy * dy + dz * dz;
|
185 |
-
const float B = ox * dx + oy * dy + oz * dz; // in fact B / 2
|
186 |
-
const float C = ox * ox + oy * oy + oz * oz - radius * radius;
|
187 |
-
|
188 |
-
const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive)
|
189 |
-
|
190 |
-
// solve theta, phi (assume y is the up axis)
|
191 |
-
const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz;
|
192 |
-
const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI)
|
193 |
-
const float phi = atan2(z, x); // [-PI, PI)
|
194 |
-
|
195 |
-
// normalize to [-1, 1]
|
196 |
-
coords[0] = 2 * theta * RPI() - 1;
|
197 |
-
coords[1] = phi * RPI();
|
198 |
-
}
|
199 |
-
|
200 |
-
|
201 |
-
void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) {
|
202 |
-
|
203 |
-
static constexpr uint32_t N_THREAD = 128;
|
204 |
-
|
205 |
-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
206 |
-
rays_o.scalar_type(), "sph_from_ray", ([&] {
|
207 |
-
kernel_sph_from_ray<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), radius, N, coords.data_ptr<scalar_t>());
|
208 |
-
}));
|
209 |
-
}
|
210 |
-
|
211 |
-
|
212 |
-
// coords: int32, [N, 3]
|
213 |
-
// indices: int32, [N]
|
214 |
-
__global__ void kernel_morton3D(
|
215 |
-
const int * __restrict__ coords,
|
216 |
-
const uint32_t N,
|
217 |
-
int * indices
|
218 |
-
) {
|
219 |
-
// parallel
|
220 |
-
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
221 |
-
if (n >= N) return;
|
222 |
-
|
223 |
-
// locate
|
224 |
-
coords += n * 3;
|
225 |
-
indices[n] = __morton3D(coords[0], coords[1], coords[2]);
|
226 |
-
}
|
227 |
-
|
228 |
-
|
229 |
-
void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) {
|
230 |
-
static constexpr uint32_t N_THREAD = 128;
|
231 |
-
kernel_morton3D<<<div_round_up(N, N_THREAD), N_THREAD>>>(coords.data_ptr<int>(), N, indices.data_ptr<int>());
|
232 |
-
}
|
233 |
-
|
234 |
-
|
235 |
-
// indices: int32, [N]
|
236 |
-
// coords: int32, [N, 3]
|
237 |
-
__global__ void kernel_morton3D_invert(
|
238 |
-
const int * __restrict__ indices,
|
239 |
-
const uint32_t N,
|
240 |
-
int * coords
|
241 |
-
) {
|
242 |
-
// parallel
|
243 |
-
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
244 |
-
if (n >= N) return;
|
245 |
-
|
246 |
-
// locate
|
247 |
-
coords += n * 3;
|
248 |
-
|
249 |
-
const int ind = indices[n];
|
250 |
-
|
251 |
-
coords[0] = __morton3D_invert(ind >> 0);
|
252 |
-
coords[1] = __morton3D_invert(ind >> 1);
|
253 |
-
coords[2] = __morton3D_invert(ind >> 2);
|
254 |
-
}
|
255 |
-
|
256 |
-
|
257 |
-
void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords) {
|
258 |
-
static constexpr uint32_t N_THREAD = 128;
|
259 |
-
kernel_morton3D_invert<<<div_round_up(N, N_THREAD), N_THREAD>>>(indices.data_ptr<int>(), N, coords.data_ptr<int>());
|
260 |
-
}
|
261 |
-
|
262 |
-
|
263 |
-
// grid: float, [C, H, H, H]
|
264 |
-
// N: int, C * H * H * H / 8
|
265 |
-
// density_thresh: float
|
266 |
-
// bitfield: uint8, [N]
|
267 |
-
template <typename scalar_t>
|
268 |
-
__global__ void kernel_packbits(
|
269 |
-
const scalar_t * __restrict__ grid,
|
270 |
-
const uint32_t N,
|
271 |
-
const float density_thresh,
|
272 |
-
uint8_t * bitfield
|
273 |
-
) {
|
274 |
-
// parallel per byte
|
275 |
-
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
276 |
-
if (n >= N) return;
|
277 |
-
|
278 |
-
// locate
|
279 |
-
grid += n * 8;
|
280 |
-
|
281 |
-
uint8_t bits = 0;
|
282 |
-
|
283 |
-
#pragma unroll
|
284 |
-
for (uint8_t i = 0; i < 8; i++) {
|
285 |
-
bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0;
|
286 |
-
}
|
287 |
-
|
288 |
-
bitfield[n] = bits;
|
289 |
-
}
|
290 |
-
|
291 |
-
|
292 |
-
void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) {
|
293 |
-
|
294 |
-
static constexpr uint32_t N_THREAD = 128;
|
295 |
-
|
296 |
-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
297 |
-
grid.scalar_type(), "packbits", ([&] {
|
298 |
-
kernel_packbits<<<div_round_up(N, N_THREAD), N_THREAD>>>(grid.data_ptr<scalar_t>(), N, density_thresh, bitfield.data_ptr<uint8_t>());
|
299 |
-
}));
|
300 |
-
}
|
301 |
-
|
302 |
-
////////////////////////////////////////////////////
|
303 |
-
///////////// training /////////////
|
304 |
-
////////////////////////////////////////////////////
|
305 |
-
|
306 |
-
// rays_o/d: [N, 3]
|
307 |
-
// grid: [CHHH / 8]
|
308 |
-
// xyzs, dirs, deltas: [M, 3], [M, 3], [M, 2]
|
309 |
-
// dirs: [M, 3]
|
310 |
-
// rays: [N, 3], idx, offset, num_steps
|
311 |
-
template <typename scalar_t>
|
312 |
-
__global__ void kernel_march_rays_train(
|
313 |
-
const scalar_t * __restrict__ rays_o,
|
314 |
-
const scalar_t * __restrict__ rays_d,
|
315 |
-
const uint8_t * __restrict__ grid,
|
316 |
-
const float bound,
|
317 |
-
const float dt_gamma, const uint32_t max_steps,
|
318 |
-
const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M,
|
319 |
-
const scalar_t* __restrict__ nears,
|
320 |
-
const scalar_t* __restrict__ fars,
|
321 |
-
scalar_t * xyzs, scalar_t * dirs, scalar_t * deltas,
|
322 |
-
int * rays,
|
323 |
-
int * counter,
|
324 |
-
const scalar_t* __restrict__ noises
|
325 |
-
) {
|
326 |
-
// parallel per ray
|
327 |
-
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
328 |
-
if (n >= N) return;
|
329 |
-
|
330 |
-
// locate
|
331 |
-
rays_o += n * 3;
|
332 |
-
rays_d += n * 3;
|
333 |
-
|
334 |
-
// ray marching
|
335 |
-
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
|
336 |
-
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
|
337 |
-
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
|
338 |
-
const float rH = 1 / (float)H;
|
339 |
-
const float H3 = H * H * H;
|
340 |
-
|
341 |
-
const float near = nears[n];
|
342 |
-
const float far = fars[n];
|
343 |
-
const float noise = noises[n];
|
344 |
-
|
345 |
-
const float dt_min = 2 * SQRT3() / max_steps;
|
346 |
-
const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;
|
347 |
-
|
348 |
-
float t0 = near;
|
349 |
-
|
350 |
-
// perturb
|
351 |
-
t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise;
|
352 |
-
|
353 |
-
// first pass: estimation of num_steps
|
354 |
-
float t = t0;
|
355 |
-
uint32_t num_steps = 0;
|
356 |
-
|
357 |
-
//if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near, far);
|
358 |
-
|
359 |
-
while (t < far && num_steps < max_steps) {
|
360 |
-
// current point
|
361 |
-
const float x = clamp(ox + t * dx, -bound, bound);
|
362 |
-
const float y = clamp(oy + t * dy, -bound, bound);
|
363 |
-
const float z = clamp(oz + t * dz, -bound, bound);
|
364 |
-
|
365 |
-
const float dt = clamp(t * dt_gamma, dt_min, dt_max);
|
366 |
-
|
367 |
-
// get mip level
|
368 |
-
const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
|
369 |
-
|
370 |
-
const float mip_bound = fminf(scalbnf(1.0f, level), bound);
|
371 |
-
const float mip_rbound = 1 / mip_bound;
|
372 |
-
|
373 |
-
// convert to nearest grid position
|
374 |
-
const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
375 |
-
const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
376 |
-
const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
377 |
-
|
378 |
-
const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
|
379 |
-
const bool occ = grid[index / 8] & (1 << (index % 8));
|
380 |
-
|
381 |
-
// if occpuied, advance a small step, and write to output
|
382 |
-
//if (n == 0) printf("t=%f density=%f vs thresh=%f step=%d\n", t, density, density_thresh, num_steps);
|
383 |
-
|
384 |
-
if (occ) {
|
385 |
-
num_steps++;
|
386 |
-
t += dt;
|
387 |
-
// else, skip a large step (basically skip a voxel grid)
|
388 |
-
} else {
|
389 |
-
// calc distance to next voxel
|
390 |
-
const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
|
391 |
-
const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
|
392 |
-
const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
|
393 |
-
|
394 |
-
const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
|
395 |
-
// step until next voxel
|
396 |
-
do {
|
397 |
-
t += clamp(t * dt_gamma, dt_min, dt_max);
|
398 |
-
} while (t < tt);
|
399 |
-
}
|
400 |
-
}
|
401 |
-
|
402 |
-
//printf("[n=%d] num_steps=%d, near=%f, far=%f, dt=%f, max_steps=%f\n", n, num_steps, near, far, dt_min, (far - near) / dt_min);
|
403 |
-
|
404 |
-
// second pass: really locate and write points & dirs
|
405 |
-
uint32_t point_index = atomicAdd(counter, num_steps);
|
406 |
-
uint32_t ray_index = atomicAdd(counter + 1, 1);
|
407 |
-
|
408 |
-
//printf("[n=%d] num_steps=%d, point_index=%d, ray_index=%d\n", n, num_steps, point_index, ray_index);
|
409 |
-
|
410 |
-
// write rays
|
411 |
-
rays[ray_index * 3] = n;
|
412 |
-
rays[ray_index * 3 + 1] = point_index;
|
413 |
-
rays[ray_index * 3 + 2] = num_steps;
|
414 |
-
|
415 |
-
if (num_steps == 0) return;
|
416 |
-
if (point_index + num_steps > M) return;
|
417 |
-
|
418 |
-
xyzs += point_index * 3;
|
419 |
-
dirs += point_index * 3;
|
420 |
-
deltas += point_index * 2;
|
421 |
-
|
422 |
-
t = t0;
|
423 |
-
uint32_t step = 0;
|
424 |
-
|
425 |
-
float last_t = t;
|
426 |
-
|
427 |
-
while (t < far && step < num_steps) {
|
428 |
-
// current point
|
429 |
-
const float x = clamp(ox + t * dx, -bound, bound);
|
430 |
-
const float y = clamp(oy + t * dy, -bound, bound);
|
431 |
-
const float z = clamp(oz + t * dz, -bound, bound);
|
432 |
-
|
433 |
-
const float dt = clamp(t * dt_gamma, dt_min, dt_max);
|
434 |
-
|
435 |
-
// get mip level
|
436 |
-
const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
|
437 |
-
|
438 |
-
const float mip_bound = fminf(scalbnf(1.0f, level), bound);
|
439 |
-
const float mip_rbound = 1 / mip_bound;
|
440 |
-
|
441 |
-
// convert to nearest grid position
|
442 |
-
const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
443 |
-
const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
444 |
-
const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
445 |
-
|
446 |
-
// query grid
|
447 |
-
const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
|
448 |
-
const bool occ = grid[index / 8] & (1 << (index % 8));
|
449 |
-
|
450 |
-
// if occpuied, advance a small step, and write to output
|
451 |
-
if (occ) {
|
452 |
-
// write step
|
453 |
-
xyzs[0] = x;
|
454 |
-
xyzs[1] = y;
|
455 |
-
xyzs[2] = z;
|
456 |
-
dirs[0] = dx;
|
457 |
-
dirs[1] = dy;
|
458 |
-
dirs[2] = dz;
|
459 |
-
t += dt;
|
460 |
-
deltas[0] = dt;
|
461 |
-
deltas[1] = t - last_t; // used to calc depth
|
462 |
-
last_t = t;
|
463 |
-
xyzs += 3;
|
464 |
-
dirs += 3;
|
465 |
-
deltas += 2;
|
466 |
-
step++;
|
467 |
-
// else, skip a large step (basically skip a voxel grid)
|
468 |
-
} else {
|
469 |
-
// calc distance to next voxel
|
470 |
-
const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
|
471 |
-
const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
|
472 |
-
const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
|
473 |
-
const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
|
474 |
-
// step until next voxel
|
475 |
-
do {
|
476 |
-
t += clamp(t * dt_gamma, dt_min, dt_max);
|
477 |
-
} while (t < tt);
|
478 |
-
}
|
479 |
-
}
|
480 |
-
}
|
481 |
-
|
482 |
-
void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises) {
|
483 |
-
|
484 |
-
static constexpr uint32_t N_THREAD = 128;
|
485 |
-
|
486 |
-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
487 |
-
rays_o.scalar_type(), "march_rays_train", ([&] {
|
488 |
-
kernel_march_rays_train<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), grid.data_ptr<uint8_t>(), bound, dt_gamma, max_steps, N, C, H, M, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), counter.data_ptr<int>(), noises.data_ptr<scalar_t>());
|
489 |
-
}));
|
490 |
-
}
|
491 |
-
|
492 |
-
|
493 |
-
// sigmas: [M]
|
494 |
-
// rgbs: [M, 3]
|
495 |
-
// deltas: [M, 2]
|
496 |
-
// rays: [N, 3], idx, offset, num_steps
|
497 |
-
// weights_sum: [N], final pixel alpha
|
498 |
-
// depth: [N,]
|
499 |
-
// image: [N, 3]
|
500 |
-
template <typename scalar_t>
|
501 |
-
__global__ void kernel_composite_rays_train_forward(
|
502 |
-
const scalar_t * __restrict__ sigmas,
|
503 |
-
const scalar_t * __restrict__ rgbs,
|
504 |
-
const scalar_t * __restrict__ deltas,
|
505 |
-
const int * __restrict__ rays,
|
506 |
-
const uint32_t M, const uint32_t N, const float T_thresh,
|
507 |
-
scalar_t * weights_sum,
|
508 |
-
scalar_t * depth,
|
509 |
-
scalar_t * image
|
510 |
-
) {
|
511 |
-
// parallel per ray
|
512 |
-
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
513 |
-
if (n >= N) return;
|
514 |
-
|
515 |
-
// locate
|
516 |
-
uint32_t index = rays[n * 3];
|
517 |
-
uint32_t offset = rays[n * 3 + 1];
|
518 |
-
uint32_t num_steps = rays[n * 3 + 2];
|
519 |
-
|
520 |
-
// empty ray, or ray that exceed max step count.
|
521 |
-
if (num_steps == 0 || offset + num_steps > M) {
|
522 |
-
weights_sum[index] = 0;
|
523 |
-
depth[index] = 0;
|
524 |
-
image[index * 3] = 0;
|
525 |
-
image[index * 3 + 1] = 0;
|
526 |
-
image[index * 3 + 2] = 0;
|
527 |
-
return;
|
528 |
-
}
|
529 |
-
|
530 |
-
sigmas += offset;
|
531 |
-
rgbs += offset * 3;
|
532 |
-
deltas += offset * 2;
|
533 |
-
|
534 |
-
// accumulate
|
535 |
-
uint32_t step = 0;
|
536 |
-
|
537 |
-
scalar_t T = 1.0f;
|
538 |
-
scalar_t r = 0, g = 0, b = 0, ws = 0, t = 0, d = 0;
|
539 |
-
|
540 |
-
while (step < num_steps) {
|
541 |
-
|
542 |
-
const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
|
543 |
-
const scalar_t weight = alpha * T;
|
544 |
-
|
545 |
-
r += weight * rgbs[0];
|
546 |
-
g += weight * rgbs[1];
|
547 |
-
b += weight * rgbs[2];
|
548 |
-
|
549 |
-
t += deltas[1]; // real delta
|
550 |
-
d += weight * t;
|
551 |
-
|
552 |
-
ws += weight;
|
553 |
-
|
554 |
-
T *= 1.0f - alpha;
|
555 |
-
|
556 |
-
// minimal remained transmittence
|
557 |
-
if (T < T_thresh) break;
|
558 |
-
|
559 |
-
//printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
|
560 |
-
|
561 |
-
// locate
|
562 |
-
sigmas++;
|
563 |
-
rgbs += 3;
|
564 |
-
deltas += 2;
|
565 |
-
|
566 |
-
step++;
|
567 |
-
}
|
568 |
-
|
569 |
-
//printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
|
570 |
-
|
571 |
-
// write
|
572 |
-
weights_sum[index] = ws; // weights_sum
|
573 |
-
depth[index] = d;
|
574 |
-
image[index * 3] = r;
|
575 |
-
image[index * 3 + 1] = g;
|
576 |
-
image[index * 3 + 2] = b;
|
577 |
-
}
|
578 |
-
|
579 |
-
|
580 |
-
void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image) {
|
581 |
-
|
582 |
-
static constexpr uint32_t N_THREAD = 128;
|
583 |
-
|
584 |
-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
585 |
-
sigmas.scalar_type(), "composite_rays_train_forward", ([&] {
|
586 |
-
kernel_composite_rays_train_forward<<<div_round_up(N, N_THREAD), N_THREAD>>>(sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), M, N, T_thresh, weights_sum.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
|
587 |
-
}));
|
588 |
-
}
|
589 |
-
|
590 |
-
|
591 |
-
// grad_weights_sum: [N,]
|
592 |
-
// grad: [N, 3]
|
593 |
-
// sigmas: [M]
|
594 |
-
// rgbs: [M, 3]
|
595 |
-
// deltas: [M, 2]
|
596 |
-
// rays: [N, 3], idx, offset, num_steps
|
597 |
-
// weights_sum: [N,], weights_sum here
|
598 |
-
// image: [N, 3]
|
599 |
-
// grad_sigmas: [M]
|
600 |
-
// grad_rgbs: [M, 3]
|
601 |
-
template <typename scalar_t>
|
602 |
-
__global__ void kernel_composite_rays_train_backward(
|
603 |
-
const scalar_t * __restrict__ grad_weights_sum,
|
604 |
-
const scalar_t * __restrict__ grad_image,
|
605 |
-
const scalar_t * __restrict__ sigmas,
|
606 |
-
const scalar_t * __restrict__ rgbs,
|
607 |
-
const scalar_t * __restrict__ deltas,
|
608 |
-
const int * __restrict__ rays,
|
609 |
-
const scalar_t * __restrict__ weights_sum,
|
610 |
-
const scalar_t * __restrict__ image,
|
611 |
-
const uint32_t M, const uint32_t N, const float T_thresh,
|
612 |
-
scalar_t * grad_sigmas,
|
613 |
-
scalar_t * grad_rgbs
|
614 |
-
) {
|
615 |
-
// parallel per ray
|
616 |
-
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
617 |
-
if (n >= N) return;
|
618 |
-
|
619 |
-
// locate
|
620 |
-
uint32_t index = rays[n * 3];
|
621 |
-
uint32_t offset = rays[n * 3 + 1];
|
622 |
-
uint32_t num_steps = rays[n * 3 + 2];
|
623 |
-
|
624 |
-
if (num_steps == 0 || offset + num_steps > M) return;
|
625 |
-
|
626 |
-
grad_weights_sum += index;
|
627 |
-
grad_image += index * 3;
|
628 |
-
weights_sum += index;
|
629 |
-
image += index * 3;
|
630 |
-
sigmas += offset;
|
631 |
-
rgbs += offset * 3;
|
632 |
-
deltas += offset * 2;
|
633 |
-
grad_sigmas += offset;
|
634 |
-
grad_rgbs += offset * 3;
|
635 |
-
|
636 |
-
// accumulate
|
637 |
-
uint32_t step = 0;
|
638 |
-
|
639 |
-
scalar_t T = 1.0f;
|
640 |
-
const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0];
|
641 |
-
scalar_t r = 0, g = 0, b = 0, ws = 0;
|
642 |
-
|
643 |
-
while (step < num_steps) {
|
644 |
-
|
645 |
-
const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
|
646 |
-
const scalar_t weight = alpha * T;
|
647 |
-
|
648 |
-
r += weight * rgbs[0];
|
649 |
-
g += weight * rgbs[1];
|
650 |
-
b += weight * rgbs[2];
|
651 |
-
ws += weight;
|
652 |
-
|
653 |
-
T *= 1.0f - alpha;
|
654 |
-
|
655 |
-
// check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation.
|
656 |
-
// write grad_rgbs
|
657 |
-
grad_rgbs[0] = grad_image[0] * weight;
|
658 |
-
grad_rgbs[1] = grad_image[1] * weight;
|
659 |
-
grad_rgbs[2] = grad_image[2] * weight;
|
660 |
-
|
661 |
-
// write grad_sigmas
|
662 |
-
grad_sigmas[0] = deltas[0] * (
|
663 |
-
grad_image[0] * (T * rgbs[0] - (r_final - r)) +
|
664 |
-
grad_image[1] * (T * rgbs[1] - (g_final - g)) +
|
665 |
-
grad_image[2] * (T * rgbs[2] - (b_final - b)) +
|
666 |
-
grad_weights_sum[0] * (1 - ws_final)
|
667 |
-
);
|
668 |
-
|
669 |
-
//printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r);
|
670 |
-
// minimal remained transmittence
|
671 |
-
if (T < T_thresh) break;
|
672 |
-
|
673 |
-
// locate
|
674 |
-
sigmas++;
|
675 |
-
rgbs += 3;
|
676 |
-
deltas += 2;
|
677 |
-
grad_sigmas++;
|
678 |
-
grad_rgbs += 3;
|
679 |
-
|
680 |
-
step++;
|
681 |
-
}
|
682 |
-
}
|
683 |
-
|
684 |
-
|
685 |
-
void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs) {
|
686 |
-
|
687 |
-
static constexpr uint32_t N_THREAD = 128;
|
688 |
-
|
689 |
-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
690 |
-
grad_image.scalar_type(), "composite_rays_train_backward", ([&] {
|
691 |
-
kernel_composite_rays_train_backward<<<div_round_up(N, N_THREAD), N_THREAD>>>(grad_weights_sum.data_ptr<scalar_t>(), grad_image.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), weights_sum.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(), M, N, T_thresh, grad_sigmas.data_ptr<scalar_t>(), grad_rgbs.data_ptr<scalar_t>());
|
692 |
-
}));
|
693 |
-
}
|
694 |
-
|
695 |
-
|
696 |
-
////////////////////////////////////////////////////
|
697 |
-
///////////// infernce /////////////
|
698 |
-
////////////////////////////////////////////////////
|
699 |
-
|
700 |
-
template <typename scalar_t>
|
701 |
-
__global__ void kernel_march_rays(
|
702 |
-
const uint32_t n_alive,
|
703 |
-
const uint32_t n_step,
|
704 |
-
const int* __restrict__ rays_alive,
|
705 |
-
const scalar_t* __restrict__ rays_t,
|
706 |
-
const scalar_t* __restrict__ rays_o,
|
707 |
-
const scalar_t* __restrict__ rays_d,
|
708 |
-
const float bound,
|
709 |
-
const float dt_gamma, const uint32_t max_steps,
|
710 |
-
const uint32_t C, const uint32_t H,
|
711 |
-
const uint8_t * __restrict__ grid,
|
712 |
-
const scalar_t* __restrict__ nears,
|
713 |
-
const scalar_t* __restrict__ fars,
|
714 |
-
scalar_t* xyzs, scalar_t* dirs, scalar_t* deltas,
|
715 |
-
const scalar_t* __restrict__ noises
|
716 |
-
) {
|
717 |
-
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
718 |
-
if (n >= n_alive) return;
|
719 |
-
|
720 |
-
const int index = rays_alive[n]; // ray id
|
721 |
-
const float noise = noises[n];
|
722 |
-
|
723 |
-
// locate
|
724 |
-
rays_o += index * 3;
|
725 |
-
rays_d += index * 3;
|
726 |
-
xyzs += n * n_step * 3;
|
727 |
-
dirs += n * n_step * 3;
|
728 |
-
deltas += n * n_step * 2;
|
729 |
-
|
730 |
-
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
|
731 |
-
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
|
732 |
-
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
|
733 |
-
const float rH = 1 / (float)H;
|
734 |
-
const float H3 = H * H * H;
|
735 |
-
|
736 |
-
float t = rays_t[index]; // current ray's t
|
737 |
-
const float near = nears[index], far = fars[index];
|
738 |
-
|
739 |
-
const float dt_min = 2 * SQRT3() / max_steps;
|
740 |
-
const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;
|
741 |
-
|
742 |
-
// march for n_step steps, record points
|
743 |
-
uint32_t step = 0;
|
744 |
-
|
745 |
-
// introduce some randomness
|
746 |
-
t += clamp(t * dt_gamma, dt_min, dt_max) * noise;
|
747 |
-
|
748 |
-
float last_t = t;
|
749 |
-
|
750 |
-
while (t < far && step < n_step) {
|
751 |
-
// current point
|
752 |
-
const float x = clamp(ox + t * dx, -bound, bound);
|
753 |
-
const float y = clamp(oy + t * dy, -bound, bound);
|
754 |
-
const float z = clamp(oz + t * dz, -bound, bound);
|
755 |
-
|
756 |
-
const float dt = clamp(t * dt_gamma, dt_min, dt_max);
|
757 |
-
|
758 |
-
// get mip level
|
759 |
-
const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
|
760 |
-
|
761 |
-
const float mip_bound = fminf(scalbnf(1, level), bound);
|
762 |
-
const float mip_rbound = 1 / mip_bound;
|
763 |
-
|
764 |
-
// convert to nearest grid position
|
765 |
-
const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
766 |
-
const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
767 |
-
const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
768 |
-
|
769 |
-
const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
|
770 |
-
const bool occ = grid[index / 8] & (1 << (index % 8));
|
771 |
-
|
772 |
-
// if occpuied, advance a small step, and write to output
|
773 |
-
if (occ) {
|
774 |
-
// write step
|
775 |
-
xyzs[0] = x;
|
776 |
-
xyzs[1] = y;
|
777 |
-
xyzs[2] = z;
|
778 |
-
dirs[0] = dx;
|
779 |
-
dirs[1] = dy;
|
780 |
-
dirs[2] = dz;
|
781 |
-
// calc dt
|
782 |
-
t += dt;
|
783 |
-
deltas[0] = dt;
|
784 |
-
deltas[1] = t - last_t; // used to calc depth
|
785 |
-
last_t = t;
|
786 |
-
// step
|
787 |
-
xyzs += 3;
|
788 |
-
dirs += 3;
|
789 |
-
deltas += 2;
|
790 |
-
step++;
|
791 |
-
|
792 |
-
// else, skip a large step (basically skip a voxel grid)
|
793 |
-
} else {
|
794 |
-
// calc distance to next voxel
|
795 |
-
const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
|
796 |
-
const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
|
797 |
-
const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
|
798 |
-
const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
|
799 |
-
// step until next voxel
|
800 |
-
do {
|
801 |
-
t += clamp(t * dt_gamma, dt_min, dt_max);
|
802 |
-
} while (t < tt);
|
803 |
-
}
|
804 |
-
}
|
805 |
-
}
|
806 |
-
|
807 |
-
|
808 |
-
void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor near, const at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises) {
|
809 |
-
static constexpr uint32_t N_THREAD = 128;
|
810 |
-
|
811 |
-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
812 |
-
rays_o.scalar_type(), "march_rays", ([&] {
|
813 |
-
kernel_march_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), bound, dt_gamma, max_steps, C, H, grid.data_ptr<uint8_t>(), near.data_ptr<scalar_t>(), far.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), noises.data_ptr<scalar_t>());
|
814 |
-
}));
|
815 |
-
}
|
816 |
-
|
817 |
-
|
818 |
-
template <typename scalar_t>
|
819 |
-
__global__ void kernel_composite_rays(
|
820 |
-
const uint32_t n_alive,
|
821 |
-
const uint32_t n_step,
|
822 |
-
const float T_thresh,
|
823 |
-
int* rays_alive,
|
824 |
-
scalar_t* rays_t,
|
825 |
-
const scalar_t* __restrict__ sigmas,
|
826 |
-
const scalar_t* __restrict__ rgbs,
|
827 |
-
const scalar_t* __restrict__ deltas,
|
828 |
-
scalar_t* weights_sum, scalar_t* depth, scalar_t* image
|
829 |
-
) {
|
830 |
-
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
831 |
-
if (n >= n_alive) return;
|
832 |
-
|
833 |
-
const int index = rays_alive[n]; // ray id
|
834 |
-
|
835 |
-
// locate
|
836 |
-
sigmas += n * n_step;
|
837 |
-
rgbs += n * n_step * 3;
|
838 |
-
deltas += n * n_step * 2;
|
839 |
-
|
840 |
-
rays_t += index;
|
841 |
-
weights_sum += index;
|
842 |
-
depth += index;
|
843 |
-
image += index * 3;
|
844 |
-
|
845 |
-
scalar_t t = rays_t[0]; // current ray's t
|
846 |
-
|
847 |
-
scalar_t weight_sum = weights_sum[0];
|
848 |
-
scalar_t d = depth[0];
|
849 |
-
scalar_t r = image[0];
|
850 |
-
scalar_t g = image[1];
|
851 |
-
scalar_t b = image[2];
|
852 |
-
|
853 |
-
// accumulate
|
854 |
-
uint32_t step = 0;
|
855 |
-
while (step < n_step) {
|
856 |
-
|
857 |
-
// ray is terminated if delta == 0
|
858 |
-
if (deltas[0] == 0) break;
|
859 |
-
|
860 |
-
const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
|
861 |
-
|
862 |
-
/*
|
863 |
-
T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j)
|
864 |
-
w_i = alpha_i * T_i
|
865 |
-
-->
|
866 |
-
T_i = 1 - \sum_{j=0}^{i-1} w_j
|
867 |
-
*/
|
868 |
-
const scalar_t T = 1 - weight_sum;
|
869 |
-
const scalar_t weight = alpha * T;
|
870 |
-
weight_sum += weight;
|
871 |
-
|
872 |
-
t += deltas[1]; // real delta
|
873 |
-
d += weight * t;
|
874 |
-
r += weight * rgbs[0];
|
875 |
-
g += weight * rgbs[1];
|
876 |
-
b += weight * rgbs[2];
|
877 |
-
|
878 |
-
//printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
|
879 |
-
|
880 |
-
// ray is terminated if T is too small
|
881 |
-
// use a larger bound to further accelerate inference
|
882 |
-
if (T < T_thresh) break;
|
883 |
-
|
884 |
-
// locate
|
885 |
-
sigmas++;
|
886 |
-
rgbs += 3;
|
887 |
-
deltas += 2;
|
888 |
-
step++;
|
889 |
-
}
|
890 |
-
|
891 |
-
//printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
|
892 |
-
|
893 |
-
// rays_alive = -1 means ray is terminated early.
|
894 |
-
if (step < n_step) {
|
895 |
-
rays_alive[n] = -1;
|
896 |
-
} else {
|
897 |
-
rays_t[0] = t;
|
898 |
-
}
|
899 |
-
|
900 |
-
weights_sum[0] = weight_sum; // this is the thing I needed!
|
901 |
-
depth[0] = d;
|
902 |
-
image[0] = r;
|
903 |
-
image[1] = g;
|
904 |
-
image[2] = b;
|
905 |
-
}
|
906 |
-
|
907 |
-
|
908 |
-
void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights, at::Tensor depth, at::Tensor image) {
|
909 |
-
static constexpr uint32_t N_THREAD = 128;
|
910 |
-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
911 |
-
image.scalar_type(), "composite_rays", ([&] {
|
912 |
-
kernel_composite_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, T_thresh, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), weights.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
|
913 |
-
}));
|
914 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raymarching/src/raymarching.h
DELETED
@@ -1,18 +0,0 @@
|
|
1 |
-
#pragma once
|
2 |
-
|
3 |
-
#include <stdint.h>
|
4 |
-
#include <torch/torch.h>
|
5 |
-
|
6 |
-
|
7 |
-
void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars);
|
8 |
-
void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords);
|
9 |
-
void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices);
|
10 |
-
void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords);
|
11 |
-
void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield);
|
12 |
-
|
13 |
-
void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises);
|
14 |
-
void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
|
15 |
-
void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs);
|
16 |
-
|
17 |
-
void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises);
|
18 |
-
void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
render_batch.py
DELETED
@@ -1,20 +0,0 @@
|
|
1 |
-
import subprocess
|
2 |
-
|
3 |
-
from ldm.base_utils import save_pickle
|
4 |
-
|
5 |
-
uids=['6f99fb8c2f1a4252b986ed5a765e1db9','8bba4678f9a349d6a29314ccf337975c','063b1b7d877a402ead76cedb06341681',
|
6 |
-
'199b7a080622422fac8140b61cc7544a','83784b6f7a064212ab50aaaaeb1d7fa7','5501434a052c49d6a8a8d9a1120fee10',
|
7 |
-
'cca62f95635f4b20aea4f35014632a55','d2e8612a21044111a7176da2bd78de05','f9e172dd733644a2b47a824e202c89d5']
|
8 |
-
|
9 |
-
# for uid in uids:
|
10 |
-
# cmds = ['blender','--background','--python','blender_script.py','--',
|
11 |
-
# '--object_path',f'objaverse_examples/{uid}/{uid}.glb',
|
12 |
-
# '--output_dir','./training_examples/input','--camera_type','random']
|
13 |
-
# subprocess.run(cmds)
|
14 |
-
#
|
15 |
-
# cmds = ['blender','--background','--python','blender_script.py','--',
|
16 |
-
# '--object_path',f'objaverse_examples/{uid}/{uid}.glb',
|
17 |
-
# '--output_dir','./training_examples/target','--camera_type','fixed']
|
18 |
-
# subprocess.run(cmds)
|
19 |
-
|
20 |
-
save_pickle(uids, f'training_examples/uid_set.pkl')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
renderer/agg_net.py
DELETED
@@ -1,83 +0,0 @@
|
|
1 |
-
import torch.nn.functional as F
|
2 |
-
import torch.nn as nn
|
3 |
-
import torch
|
4 |
-
|
5 |
-
def weights_init(m):
|
6 |
-
if isinstance(m, nn.Linear):
|
7 |
-
nn.init.kaiming_normal_(m.weight.data)
|
8 |
-
if m.bias is not None:
|
9 |
-
nn.init.zeros_(m.bias.data)
|
10 |
-
|
11 |
-
class NeRF(nn.Module):
|
12 |
-
def __init__(self, vol_n=8+8, feat_ch=8+16+32+3, hid_n=64):
|
13 |
-
super(NeRF, self).__init__()
|
14 |
-
self.hid_n = hid_n
|
15 |
-
self.agg = Agg(feat_ch)
|
16 |
-
self.lr0 = nn.Sequential(nn.Linear(vol_n+16, hid_n), nn.ReLU())
|
17 |
-
self.sigma = nn.Sequential(nn.Linear(hid_n, 1), nn.Softplus())
|
18 |
-
self.color = nn.Sequential(
|
19 |
-
nn.Linear(16+vol_n+feat_ch+hid_n+4, hid_n), # agg_feats+vox_feat+img_feat+lr0_feats+dir
|
20 |
-
nn.ReLU(),
|
21 |
-
nn.Linear(hid_n, 1)
|
22 |
-
)
|
23 |
-
self.lr0.apply(weights_init)
|
24 |
-
self.sigma.apply(weights_init)
|
25 |
-
self.color.apply(weights_init)
|
26 |
-
|
27 |
-
def forward(self, vox_feat, img_feat_rgb_dir, source_img_mask):
|
28 |
-
# assert torch.sum(torch.sum(source_img_mask,1)<2)==0
|
29 |
-
b, d, n, _ = img_feat_rgb_dir.shape # b,d,n,f=8+16+32+3+4
|
30 |
-
agg_feat = self.agg(img_feat_rgb_dir, source_img_mask) # b,d,f=16
|
31 |
-
x = self.lr0(torch.cat((vox_feat, agg_feat), dim=-1)) # b,d,f=64
|
32 |
-
sigma = self.sigma(x) # b,d,1
|
33 |
-
|
34 |
-
x = torch.cat((x, vox_feat, agg_feat), dim=-1) # b,d,f=16+16+64
|
35 |
-
x = x.view(b, d, 1, x.shape[-1]).repeat(1, 1, n, 1)
|
36 |
-
x = torch.cat((x, img_feat_rgb_dir), dim=-1)
|
37 |
-
logits = self.color(x)
|
38 |
-
source_img_mask_ = source_img_mask.reshape(b, 1, n, 1).repeat(1, logits.shape[1], 1, 1) == 0
|
39 |
-
logits[source_img_mask_] = -1e7
|
40 |
-
color_weight = F.softmax(logits, dim=-2)
|
41 |
-
color = torch.sum((img_feat_rgb_dir[..., -7:-4] * color_weight), dim=-2)
|
42 |
-
return color, sigma
|
43 |
-
|
44 |
-
class Agg(nn.Module):
|
45 |
-
def __init__(self, feat_ch):
|
46 |
-
super(Agg, self).__init__()
|
47 |
-
self.feat_ch = feat_ch
|
48 |
-
self.view_fc = nn.Sequential(nn.Linear(4, feat_ch), nn.ReLU())
|
49 |
-
self.view_fc.apply(weights_init)
|
50 |
-
self.global_fc = nn.Sequential(nn.Linear(feat_ch*3, 32), nn.ReLU())
|
51 |
-
|
52 |
-
self.agg_w_fc = nn.Linear(32, 1)
|
53 |
-
self.fc = nn.Linear(32, 16)
|
54 |
-
self.global_fc.apply(weights_init)
|
55 |
-
self.agg_w_fc.apply(weights_init)
|
56 |
-
self.fc.apply(weights_init)
|
57 |
-
|
58 |
-
def masked_mean_var(self, img_feat_rgb, source_img_mask):
|
59 |
-
# img_feat_rgb: b,d,n,f source_img_mask: b,n
|
60 |
-
b, n = source_img_mask.shape
|
61 |
-
source_img_mask = source_img_mask.view(b, 1, n, 1)
|
62 |
-
mean = torch.sum(source_img_mask * img_feat_rgb, dim=-2)/ (torch.sum(source_img_mask, dim=-2) + 1e-5)
|
63 |
-
var = torch.sum((img_feat_rgb - mean.unsqueeze(-2)) ** 2 * source_img_mask, dim=-2) / (torch.sum(source_img_mask, dim=-2) + 1e-5)
|
64 |
-
return mean, var
|
65 |
-
|
66 |
-
def forward(self, img_feat_rgb_dir, source_img_mask):
|
67 |
-
# img_feat_rgb_dir b,d,n,f
|
68 |
-
b, d, n, _ = img_feat_rgb_dir.shape
|
69 |
-
view_feat = self.view_fc(img_feat_rgb_dir[..., -4:]) # b,d,n,f-4
|
70 |
-
img_feat_rgb = img_feat_rgb_dir[..., :-4] + view_feat
|
71 |
-
|
72 |
-
mean_feat, var_feat = self.masked_mean_var(img_feat_rgb, source_img_mask)
|
73 |
-
var_feat = var_feat.view(b, -1, 1, self.feat_ch).repeat(1, 1, n, 1)
|
74 |
-
avg_feat = mean_feat.view(b, -1, 1, self.feat_ch).repeat(1, 1, n, 1)
|
75 |
-
|
76 |
-
feat = torch.cat([img_feat_rgb, var_feat, avg_feat], dim=-1) # b,d,n,f
|
77 |
-
global_feat = self.global_fc(feat) # b,d,n,f
|
78 |
-
logits = self.agg_w_fc(global_feat) # b,d,n,1
|
79 |
-
source_img_mask_ = source_img_mask.reshape(b, 1, n, 1).repeat(1, logits.shape[1], 1, 1) == 0
|
80 |
-
logits[source_img_mask_] = -1e7
|
81 |
-
agg_w = F.softmax(logits, dim=-2)
|
82 |
-
im_feat = (global_feat * agg_w).sum(dim=-2)
|
83 |
-
return self.fc(im_feat)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
renderer/cost_reg_net.py
DELETED
@@ -1,95 +0,0 @@
|
|
1 |
-
import torch.nn as nn
|
2 |
-
|
3 |
-
class ConvBnReLU3D(nn.Module):
|
4 |
-
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1, norm_act=nn.BatchNorm3d):
|
5 |
-
super(ConvBnReLU3D, self).__init__()
|
6 |
-
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
|
7 |
-
self.bn = norm_act(out_channels)
|
8 |
-
self.relu = nn.ReLU(inplace=True)
|
9 |
-
|
10 |
-
def forward(self, x):
|
11 |
-
return self.relu(self.bn(self.conv(x)))
|
12 |
-
|
13 |
-
class CostRegNet(nn.Module):
|
14 |
-
def __init__(self, in_channels, norm_act=nn.BatchNorm3d):
|
15 |
-
super(CostRegNet, self).__init__()
|
16 |
-
self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act)
|
17 |
-
|
18 |
-
self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act)
|
19 |
-
self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act)
|
20 |
-
|
21 |
-
self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act)
|
22 |
-
self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act)
|
23 |
-
|
24 |
-
self.conv5 = ConvBnReLU3D(32, 64, stride=2, norm_act=norm_act)
|
25 |
-
self.conv6 = ConvBnReLU3D(64, 64, norm_act=norm_act)
|
26 |
-
|
27 |
-
self.conv7 = nn.Sequential(
|
28 |
-
nn.ConvTranspose3d(64, 32, 3, padding=1, output_padding=1, stride=2, bias=False),
|
29 |
-
norm_act(32)
|
30 |
-
)
|
31 |
-
|
32 |
-
self.conv9 = nn.Sequential(
|
33 |
-
nn.ConvTranspose3d(32, 16, 3, padding=1, output_padding=1, stride=2, bias=False),
|
34 |
-
norm_act(16)
|
35 |
-
)
|
36 |
-
|
37 |
-
self.conv11 = nn.Sequential(
|
38 |
-
nn.ConvTranspose3d(16, 8, 3, padding=1, output_padding=1,stride=2, bias=False),
|
39 |
-
norm_act(8)
|
40 |
-
)
|
41 |
-
self.depth_conv = nn.Sequential(nn.Conv3d(8, 1, 3, padding=1, bias=False))
|
42 |
-
self.feat_conv = nn.Sequential(nn.Conv3d(8, 8, 3, padding=1, bias=False))
|
43 |
-
|
44 |
-
def forward(self, x):
|
45 |
-
conv0 = self.conv0(x)
|
46 |
-
conv2 = self.conv2(self.conv1(conv0))
|
47 |
-
conv4 = self.conv4(self.conv3(conv2))
|
48 |
-
x = self.conv6(self.conv5(conv4))
|
49 |
-
x = conv4 + self.conv7(x)
|
50 |
-
del conv4
|
51 |
-
x = conv2 + self.conv9(x)
|
52 |
-
del conv2
|
53 |
-
x = conv0 + self.conv11(x)
|
54 |
-
del conv0
|
55 |
-
feat = self.feat_conv(x)
|
56 |
-
depth = self.depth_conv(x)
|
57 |
-
return feat, depth
|
58 |
-
|
59 |
-
|
60 |
-
class MinCostRegNet(nn.Module):
|
61 |
-
def __init__(self, in_channels, norm_act=nn.BatchNorm3d):
|
62 |
-
super(MinCostRegNet, self).__init__()
|
63 |
-
self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act)
|
64 |
-
|
65 |
-
self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act)
|
66 |
-
self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act)
|
67 |
-
|
68 |
-
self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act)
|
69 |
-
self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act)
|
70 |
-
|
71 |
-
self.conv9 = nn.Sequential(
|
72 |
-
nn.ConvTranspose3d(32, 16, 3, padding=1, output_padding=1,
|
73 |
-
stride=2, bias=False),
|
74 |
-
norm_act(16))
|
75 |
-
|
76 |
-
self.conv11 = nn.Sequential(
|
77 |
-
nn.ConvTranspose3d(16, 8, 3, padding=1, output_padding=1,
|
78 |
-
stride=2, bias=False),
|
79 |
-
norm_act(8))
|
80 |
-
|
81 |
-
self.depth_conv = nn.Sequential(nn.Conv3d(8, 1, 3, padding=1, bias=False))
|
82 |
-
self.feat_conv = nn.Sequential(nn.Conv3d(8, 8, 3, padding=1, bias=False))
|
83 |
-
|
84 |
-
def forward(self, x):
|
85 |
-
conv0 = self.conv0(x)
|
86 |
-
conv2 = self.conv2(self.conv1(conv0))
|
87 |
-
conv4 = self.conv4(self.conv3(conv2))
|
88 |
-
x = conv4
|
89 |
-
x = conv2 + self.conv9(x)
|
90 |
-
del conv2
|
91 |
-
x = conv0 + self.conv11(x)
|
92 |
-
del conv0
|
93 |
-
feat = self.feat_conv(x)
|
94 |
-
depth = self.depth_conv(x)
|
95 |
-
return feat, depth
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
renderer/dummy_dataset.py
DELETED
@@ -1,40 +0,0 @@
|
|
1 |
-
import pytorch_lightning as pl
|
2 |
-
from torch.utils.data import Dataset
|
3 |
-
import webdataset as wds
|
4 |
-
from torch.utils.data.distributed import DistributedSampler
|
5 |
-
class DummyDataset(pl.LightningDataModule):
|
6 |
-
def __init__(self,seed):
|
7 |
-
super().__init__()
|
8 |
-
|
9 |
-
def setup(self, stage):
|
10 |
-
if stage in ['fit']:
|
11 |
-
self.train_dataset = DummyData(True)
|
12 |
-
self.val_dataset = DummyData(False)
|
13 |
-
else:
|
14 |
-
raise NotImplementedError
|
15 |
-
|
16 |
-
def train_dataloader(self):
|
17 |
-
return wds.WebLoader(self.train_dataset, batch_size=1, num_workers=0, shuffle=False)
|
18 |
-
|
19 |
-
def val_dataloader(self):
|
20 |
-
return wds.WebLoader(self.val_dataset, batch_size=1, num_workers=0, shuffle=False)
|
21 |
-
|
22 |
-
def test_dataloader(self):
|
23 |
-
return wds.WebLoader(DummyData(False))
|
24 |
-
|
25 |
-
class DummyData(Dataset):
|
26 |
-
def __init__(self,is_train):
|
27 |
-
self.is_train=is_train
|
28 |
-
|
29 |
-
def __len__(self):
|
30 |
-
if self.is_train:
|
31 |
-
return 99999999
|
32 |
-
else:
|
33 |
-
return 1
|
34 |
-
|
35 |
-
def __getitem__(self, index):
|
36 |
-
return {}
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
renderer/feature_net.py
DELETED
@@ -1,42 +0,0 @@
|
|
1 |
-
import torch.nn as nn
|
2 |
-
import torch.nn.functional as F
|
3 |
-
|
4 |
-
class ConvBnReLU(nn.Module):
|
5 |
-
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1, norm_act=nn.BatchNorm2d):
|
6 |
-
super(ConvBnReLU, self).__init__()
|
7 |
-
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
|
8 |
-
self.bn = norm_act(out_channels)
|
9 |
-
self.relu = nn.ReLU(inplace=True)
|
10 |
-
|
11 |
-
def forward(self, x):
|
12 |
-
return self.relu(self.bn(self.conv(x)))
|
13 |
-
|
14 |
-
class FeatureNet(nn.Module):
|
15 |
-
def __init__(self, norm_act=nn.BatchNorm2d):
|
16 |
-
super(FeatureNet, self).__init__()
|
17 |
-
self.conv0 = nn.Sequential(ConvBnReLU(3, 8, 3, 1, 1, norm_act=norm_act), ConvBnReLU(8, 8, 3, 1, 1, norm_act=norm_act))
|
18 |
-
self.conv1 = nn.Sequential(ConvBnReLU(8, 16, 5, 2, 2, norm_act=norm_act), ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act))
|
19 |
-
self.conv2 = nn.Sequential(ConvBnReLU(16, 32, 5, 2, 2, norm_act=norm_act), ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act))
|
20 |
-
|
21 |
-
self.toplayer = nn.Conv2d(32, 32, 1)
|
22 |
-
self.lat1 = nn.Conv2d(16, 32, 1)
|
23 |
-
self.lat0 = nn.Conv2d(8, 32, 1)
|
24 |
-
|
25 |
-
self.smooth1 = nn.Conv2d(32, 16, 3, padding=1)
|
26 |
-
self.smooth0 = nn.Conv2d(32, 8, 3, padding=1)
|
27 |
-
|
28 |
-
def _upsample_add(self, x, y):
|
29 |
-
return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + y
|
30 |
-
|
31 |
-
def forward(self, x):
|
32 |
-
conv0 = self.conv0(x)
|
33 |
-
conv1 = self.conv1(conv0)
|
34 |
-
conv2 = self.conv2(conv1)
|
35 |
-
feat2 = self.toplayer(conv2)
|
36 |
-
feat1 = self._upsample_add(feat2, self.lat1(conv1))
|
37 |
-
feat0 = self._upsample_add(feat1, self.lat0(conv0))
|
38 |
-
feat1 = self.smooth1(feat1)
|
39 |
-
feat0 = self.smooth0(feat0)
|
40 |
-
return feat2, feat1, feat0
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
renderer/neus_networks.py
DELETED
@@ -1,503 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
|
3 |
-
import numpy as np
|
4 |
-
import torch
|
5 |
-
import torch.nn as nn
|
6 |
-
import torch.nn.functional as F
|
7 |
-
import tinycudann as tcnn
|
8 |
-
|
9 |
-
# Positional encoding embedding. Code was taken from https://github.com/bmild/nerf.
|
10 |
-
class Embedder:
|
11 |
-
def __init__(self, **kwargs):
|
12 |
-
self.kwargs = kwargs
|
13 |
-
self.create_embedding_fn()
|
14 |
-
|
15 |
-
def create_embedding_fn(self):
|
16 |
-
embed_fns = []
|
17 |
-
d = self.kwargs['input_dims']
|
18 |
-
out_dim = 0
|
19 |
-
if self.kwargs['include_input']:
|
20 |
-
embed_fns.append(lambda x: x)
|
21 |
-
out_dim += d
|
22 |
-
|
23 |
-
max_freq = self.kwargs['max_freq_log2']
|
24 |
-
N_freqs = self.kwargs['num_freqs']
|
25 |
-
|
26 |
-
if self.kwargs['log_sampling']:
|
27 |
-
freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
|
28 |
-
else:
|
29 |
-
freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, N_freqs)
|
30 |
-
|
31 |
-
for freq in freq_bands:
|
32 |
-
for p_fn in self.kwargs['periodic_fns']:
|
33 |
-
embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
|
34 |
-
out_dim += d
|
35 |
-
|
36 |
-
self.embed_fns = embed_fns
|
37 |
-
self.out_dim = out_dim
|
38 |
-
|
39 |
-
def embed(self, inputs):
|
40 |
-
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
|
41 |
-
|
42 |
-
|
43 |
-
def get_embedder(multires, input_dims=3):
|
44 |
-
embed_kwargs = {
|
45 |
-
'include_input': True,
|
46 |
-
'input_dims': input_dims,
|
47 |
-
'max_freq_log2': multires - 1,
|
48 |
-
'num_freqs': multires,
|
49 |
-
'log_sampling': True,
|
50 |
-
'periodic_fns': [torch.sin, torch.cos],
|
51 |
-
}
|
52 |
-
|
53 |
-
embedder_obj = Embedder(**embed_kwargs)
|
54 |
-
|
55 |
-
def embed(x, eo=embedder_obj): return eo.embed(x)
|
56 |
-
|
57 |
-
return embed, embedder_obj.out_dim
|
58 |
-
|
59 |
-
|
60 |
-
class SDFNetwork(nn.Module):
|
61 |
-
def __init__(self, d_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0, bias=0.5,
|
62 |
-
scale=1, geometric_init=True, weight_norm=True, inside_outside=False):
|
63 |
-
super(SDFNetwork, self).__init__()
|
64 |
-
|
65 |
-
dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]
|
66 |
-
|
67 |
-
self.embed_fn_fine = None
|
68 |
-
|
69 |
-
if multires > 0:
|
70 |
-
embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
|
71 |
-
self.embed_fn_fine = embed_fn
|
72 |
-
dims[0] = input_ch
|
73 |
-
|
74 |
-
self.num_layers = len(dims)
|
75 |
-
self.skip_in = skip_in
|
76 |
-
self.scale = scale
|
77 |
-
|
78 |
-
for l in range(0, self.num_layers - 1):
|
79 |
-
if l + 1 in self.skip_in:
|
80 |
-
out_dim = dims[l + 1] - dims[0]
|
81 |
-
else:
|
82 |
-
out_dim = dims[l + 1]
|
83 |
-
|
84 |
-
lin = nn.Linear(dims[l], out_dim)
|
85 |
-
|
86 |
-
if geometric_init:
|
87 |
-
if l == self.num_layers - 2:
|
88 |
-
if not inside_outside:
|
89 |
-
torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
|
90 |
-
torch.nn.init.constant_(lin.bias, -bias)
|
91 |
-
else:
|
92 |
-
torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
|
93 |
-
torch.nn.init.constant_(lin.bias, bias)
|
94 |
-
elif multires > 0 and l == 0:
|
95 |
-
torch.nn.init.constant_(lin.bias, 0.0)
|
96 |
-
torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
|
97 |
-
torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
|
98 |
-
elif multires > 0 and l in self.skip_in:
|
99 |
-
torch.nn.init.constant_(lin.bias, 0.0)
|
100 |
-
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
|
101 |
-
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
|
102 |
-
else:
|
103 |
-
torch.nn.init.constant_(lin.bias, 0.0)
|
104 |
-
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
|
105 |
-
|
106 |
-
if weight_norm:
|
107 |
-
lin = nn.utils.weight_norm(lin)
|
108 |
-
|
109 |
-
setattr(self, "lin" + str(l), lin)
|
110 |
-
|
111 |
-
self.activation = nn.Softplus(beta=100)
|
112 |
-
|
113 |
-
def forward(self, inputs):
|
114 |
-
inputs = inputs * self.scale
|
115 |
-
if self.embed_fn_fine is not None:
|
116 |
-
inputs = self.embed_fn_fine(inputs)
|
117 |
-
|
118 |
-
x = inputs
|
119 |
-
for l in range(0, self.num_layers - 1):
|
120 |
-
lin = getattr(self, "lin" + str(l))
|
121 |
-
|
122 |
-
if l in self.skip_in:
|
123 |
-
x = torch.cat([x, inputs], -1) / np.sqrt(2)
|
124 |
-
|
125 |
-
x = lin(x)
|
126 |
-
|
127 |
-
if l < self.num_layers - 2:
|
128 |
-
x = self.activation(x)
|
129 |
-
|
130 |
-
return x
|
131 |
-
|
132 |
-
def sdf(self, x):
|
133 |
-
return self.forward(x)[..., :1]
|
134 |
-
|
135 |
-
def sdf_hidden_appearance(self, x):
|
136 |
-
return self.forward(x)
|
137 |
-
|
138 |
-
def gradient(self, x):
|
139 |
-
x.requires_grad_(True)
|
140 |
-
with torch.enable_grad():
|
141 |
-
y = self.sdf(x)
|
142 |
-
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
|
143 |
-
gradients = torch.autograd.grad(
|
144 |
-
outputs=y,
|
145 |
-
inputs=x,
|
146 |
-
grad_outputs=d_output,
|
147 |
-
create_graph=True,
|
148 |
-
retain_graph=True,
|
149 |
-
only_inputs=True)[0]
|
150 |
-
return gradients
|
151 |
-
|
152 |
-
def sdf_normal(self, x):
|
153 |
-
x.requires_grad_(True)
|
154 |
-
with torch.enable_grad():
|
155 |
-
y = self.sdf(x)
|
156 |
-
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
|
157 |
-
gradients = torch.autograd.grad(
|
158 |
-
outputs=y,
|
159 |
-
inputs=x,
|
160 |
-
grad_outputs=d_output,
|
161 |
-
create_graph=True,
|
162 |
-
retain_graph=True,
|
163 |
-
only_inputs=True)[0]
|
164 |
-
return y[..., :1].detach(), gradients.detach()
|
165 |
-
|
166 |
-
class SDFNetworkWithFeature(nn.Module):
|
167 |
-
def __init__(self, cube, dp_in, df_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0, bias=0.5,
|
168 |
-
scale=1, geometric_init=True, weight_norm=True, inside_outside=False, cube_length=0.5):
|
169 |
-
super().__init__()
|
170 |
-
|
171 |
-
self.register_buffer("cube", cube)
|
172 |
-
self.cube_length = cube_length
|
173 |
-
dims = [dp_in+df_in] + [d_hidden for _ in range(n_layers)] + [d_out]
|
174 |
-
|
175 |
-
self.embed_fn_fine = None
|
176 |
-
|
177 |
-
if multires > 0:
|
178 |
-
embed_fn, input_ch = get_embedder(multires, input_dims=dp_in)
|
179 |
-
self.embed_fn_fine = embed_fn
|
180 |
-
dims[0] = input_ch + df_in
|
181 |
-
|
182 |
-
self.num_layers = len(dims)
|
183 |
-
self.skip_in = skip_in
|
184 |
-
self.scale = scale
|
185 |
-
|
186 |
-
for l in range(0, self.num_layers - 1):
|
187 |
-
if l + 1 in self.skip_in:
|
188 |
-
out_dim = dims[l + 1] - dims[0]
|
189 |
-
else:
|
190 |
-
out_dim = dims[l + 1]
|
191 |
-
|
192 |
-
lin = nn.Linear(dims[l], out_dim)
|
193 |
-
|
194 |
-
if geometric_init:
|
195 |
-
if l == self.num_layers - 2:
|
196 |
-
if not inside_outside:
|
197 |
-
torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
|
198 |
-
torch.nn.init.constant_(lin.bias, -bias)
|
199 |
-
else:
|
200 |
-
torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
|
201 |
-
torch.nn.init.constant_(lin.bias, bias)
|
202 |
-
elif multires > 0 and l == 0:
|
203 |
-
torch.nn.init.constant_(lin.bias, 0.0)
|
204 |
-
torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
|
205 |
-
torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
|
206 |
-
elif multires > 0 and l in self.skip_in:
|
207 |
-
torch.nn.init.constant_(lin.bias, 0.0)
|
208 |
-
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
|
209 |
-
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
|
210 |
-
else:
|
211 |
-
torch.nn.init.constant_(lin.bias, 0.0)
|
212 |
-
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
|
213 |
-
|
214 |
-
if weight_norm:
|
215 |
-
lin = nn.utils.weight_norm(lin)
|
216 |
-
|
217 |
-
setattr(self, "lin" + str(l), lin)
|
218 |
-
|
219 |
-
self.activation = nn.Softplus(beta=100)
|
220 |
-
|
221 |
-
def forward(self, points):
|
222 |
-
points = points * self.scale
|
223 |
-
|
224 |
-
# note: point*2 because the cube is [-0.5,0.5]
|
225 |
-
with torch.no_grad():
|
226 |
-
feats = F.grid_sample(self.cube, points.view(1,-1,1,1,3)/self.cube_length, mode='bilinear', align_corners=True, padding_mode='zeros').detach()
|
227 |
-
feats = feats.view(self.cube.shape[1], -1).permute(1,0).view(*points.shape[:-1], -1)
|
228 |
-
if self.embed_fn_fine is not None:
|
229 |
-
points = self.embed_fn_fine(points)
|
230 |
-
|
231 |
-
x = torch.cat([points, feats], -1)
|
232 |
-
for l in range(0, self.num_layers - 1):
|
233 |
-
lin = getattr(self, "lin" + str(l))
|
234 |
-
|
235 |
-
if l in self.skip_in:
|
236 |
-
x = torch.cat([x, points, feats], -1) / np.sqrt(2)
|
237 |
-
|
238 |
-
x = lin(x)
|
239 |
-
|
240 |
-
if l < self.num_layers - 2:
|
241 |
-
x = self.activation(x)
|
242 |
-
|
243 |
-
# concat feats
|
244 |
-
x = torch.cat([x, feats], -1)
|
245 |
-
return x
|
246 |
-
|
247 |
-
def sdf(self, x):
|
248 |
-
return self.forward(x)[..., :1]
|
249 |
-
|
250 |
-
def sdf_hidden_appearance(self, x):
|
251 |
-
return self.forward(x)
|
252 |
-
|
253 |
-
def gradient(self, x):
|
254 |
-
x.requires_grad_(True)
|
255 |
-
with torch.enable_grad():
|
256 |
-
y = self.sdf(x)
|
257 |
-
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
|
258 |
-
gradients = torch.autograd.grad(
|
259 |
-
outputs=y,
|
260 |
-
inputs=x,
|
261 |
-
grad_outputs=d_output,
|
262 |
-
create_graph=True,
|
263 |
-
retain_graph=True,
|
264 |
-
only_inputs=True)[0]
|
265 |
-
return gradients
|
266 |
-
|
267 |
-
def sdf_normal(self, x):
|
268 |
-
x.requires_grad_(True)
|
269 |
-
with torch.enable_grad():
|
270 |
-
y = self.sdf(x)
|
271 |
-
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
|
272 |
-
gradients = torch.autograd.grad(
|
273 |
-
outputs=y,
|
274 |
-
inputs=x,
|
275 |
-
grad_outputs=d_output,
|
276 |
-
create_graph=True,
|
277 |
-
retain_graph=True,
|
278 |
-
only_inputs=True)[0]
|
279 |
-
return y[..., :1].detach(), gradients.detach()
|
280 |
-
|
281 |
-
|
282 |
-
class VanillaMLP(nn.Module):
|
283 |
-
def __init__(self, dim_in, dim_out, n_neurons, n_hidden_layers):
|
284 |
-
super().__init__()
|
285 |
-
self.n_neurons, self.n_hidden_layers = n_neurons, n_hidden_layers
|
286 |
-
self.sphere_init, self.weight_norm = True, True
|
287 |
-
self.sphere_init_radius = 0.5
|
288 |
-
self.layers = [self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False), self.make_activation()]
|
289 |
-
for i in range(self.n_hidden_layers - 1):
|
290 |
-
self.layers += [self.make_linear(self.n_neurons, self.n_neurons, is_first=False, is_last=False), self.make_activation()]
|
291 |
-
self.layers += [self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True)]
|
292 |
-
self.layers = nn.Sequential(*self.layers)
|
293 |
-
|
294 |
-
@torch.cuda.amp.autocast(False)
|
295 |
-
def forward(self, x):
|
296 |
-
x = self.layers(x.float())
|
297 |
-
return x
|
298 |
-
|
299 |
-
def make_linear(self, dim_in, dim_out, is_first, is_last):
|
300 |
-
layer = nn.Linear(dim_in, dim_out, bias=True) # network without bias will degrade quality
|
301 |
-
if self.sphere_init:
|
302 |
-
if is_last:
|
303 |
-
torch.nn.init.constant_(layer.bias, -self.sphere_init_radius)
|
304 |
-
torch.nn.init.normal_(layer.weight, mean=math.sqrt(math.pi) / math.sqrt(dim_in), std=0.0001)
|
305 |
-
elif is_first:
|
306 |
-
torch.nn.init.constant_(layer.bias, 0.0)
|
307 |
-
torch.nn.init.constant_(layer.weight[:, 3:], 0.0)
|
308 |
-
torch.nn.init.normal_(layer.weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(dim_out))
|
309 |
-
else:
|
310 |
-
torch.nn.init.constant_(layer.bias, 0.0)
|
311 |
-
torch.nn.init.normal_(layer.weight, 0.0, math.sqrt(2) / math.sqrt(dim_out))
|
312 |
-
else:
|
313 |
-
torch.nn.init.constant_(layer.bias, 0.0)
|
314 |
-
torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')
|
315 |
-
|
316 |
-
if self.weight_norm:
|
317 |
-
layer = nn.utils.weight_norm(layer)
|
318 |
-
return layer
|
319 |
-
|
320 |
-
def make_activation(self):
|
321 |
-
if self.sphere_init:
|
322 |
-
return nn.Softplus(beta=100)
|
323 |
-
else:
|
324 |
-
return nn.ReLU(inplace=True)
|
325 |
-
|
326 |
-
|
327 |
-
class SDFHashGridNetwork(nn.Module):
|
328 |
-
def __init__(self, bound=0.5, feats_dim=13):
|
329 |
-
super().__init__()
|
330 |
-
self.bound = bound
|
331 |
-
# max_resolution = 32
|
332 |
-
# base_resolution = 16
|
333 |
-
# n_levels = 4
|
334 |
-
# log2_hashmap_size = 16
|
335 |
-
# n_features_per_level = 8
|
336 |
-
max_resolution = 2048
|
337 |
-
base_resolution = 16
|
338 |
-
n_levels = 16
|
339 |
-
log2_hashmap_size = 19
|
340 |
-
n_features_per_level = 2
|
341 |
-
|
342 |
-
# max_res = base_res * t^(k-1)
|
343 |
-
per_level_scale = (max_resolution / base_resolution)** (1 / (n_levels - 1))
|
344 |
-
|
345 |
-
self.encoder = tcnn.Encoding(
|
346 |
-
n_input_dims=3,
|
347 |
-
encoding_config={
|
348 |
-
"otype": "HashGrid",
|
349 |
-
"n_levels": n_levels,
|
350 |
-
"n_features_per_level": n_features_per_level,
|
351 |
-
"log2_hashmap_size": log2_hashmap_size,
|
352 |
-
"base_resolution": base_resolution,
|
353 |
-
"per_level_scale": per_level_scale,
|
354 |
-
},
|
355 |
-
)
|
356 |
-
self.sdf_mlp = VanillaMLP(n_levels*n_features_per_level+3,feats_dim,64,1)
|
357 |
-
|
358 |
-
def forward(self, x):
|
359 |
-
shape = x.shape[:-1]
|
360 |
-
x = x.reshape(-1, 3)
|
361 |
-
x_ = (x + self.bound) / (2 * self.bound)
|
362 |
-
feats = self.encoder(x_)
|
363 |
-
feats = torch.cat([x, feats], 1)
|
364 |
-
|
365 |
-
feats = self.sdf_mlp(feats)
|
366 |
-
feats = feats.reshape(*shape,-1)
|
367 |
-
return feats
|
368 |
-
|
369 |
-
def sdf(self, x):
|
370 |
-
return self(x)[...,:1]
|
371 |
-
|
372 |
-
def gradient(self, x):
|
373 |
-
x.requires_grad_(True)
|
374 |
-
with torch.enable_grad():
|
375 |
-
y = self.sdf(x)
|
376 |
-
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
|
377 |
-
gradients = torch.autograd.grad(
|
378 |
-
outputs=y,
|
379 |
-
inputs=x,
|
380 |
-
grad_outputs=d_output,
|
381 |
-
create_graph=True,
|
382 |
-
retain_graph=True,
|
383 |
-
only_inputs=True)[0]
|
384 |
-
return gradients
|
385 |
-
|
386 |
-
def sdf_normal(self, x):
|
387 |
-
x.requires_grad_(True)
|
388 |
-
with torch.enable_grad():
|
389 |
-
y = self.sdf(x)
|
390 |
-
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
|
391 |
-
gradients = torch.autograd.grad(
|
392 |
-
outputs=y,
|
393 |
-
inputs=x,
|
394 |
-
grad_outputs=d_output,
|
395 |
-
create_graph=True,
|
396 |
-
retain_graph=True,
|
397 |
-
only_inputs=True)[0]
|
398 |
-
return y[..., :1].detach(), gradients.detach()
|
399 |
-
|
400 |
-
class RenderingFFNetwork(nn.Module):
|
401 |
-
def __init__(self, in_feats_dim=12):
|
402 |
-
super().__init__()
|
403 |
-
self.dir_encoder = tcnn.Encoding(
|
404 |
-
n_input_dims=3,
|
405 |
-
encoding_config={
|
406 |
-
"otype": "SphericalHarmonics",
|
407 |
-
"degree": 4,
|
408 |
-
},
|
409 |
-
)
|
410 |
-
self.color_mlp = tcnn.Network(
|
411 |
-
n_input_dims = in_feats_dim + 3 + self.dir_encoder.n_output_dims,
|
412 |
-
n_output_dims = 3,
|
413 |
-
network_config={
|
414 |
-
"otype": "FullyFusedMLP",
|
415 |
-
"activation": "ReLU",
|
416 |
-
"output_activation": "none",
|
417 |
-
"n_neurons": 64,
|
418 |
-
"n_hidden_layers": 2,
|
419 |
-
},
|
420 |
-
)
|
421 |
-
|
422 |
-
def forward(self, points, normals, view_dirs, feature_vectors):
|
423 |
-
normals = F.normalize(normals, dim=-1)
|
424 |
-
view_dirs = F.normalize(view_dirs, dim=-1)
|
425 |
-
reflective = torch.sum(view_dirs * normals, -1, keepdim=True) * normals * 2 - view_dirs
|
426 |
-
|
427 |
-
x = torch.cat([feature_vectors, normals, self.dir_encoder(reflective)], -1)
|
428 |
-
colors = self.color_mlp(x).float()
|
429 |
-
colors = F.sigmoid(colors)
|
430 |
-
return colors
|
431 |
-
|
432 |
-
# This implementation is borrowed from IDR: https://github.com/lioryariv/idr
|
433 |
-
class RenderingNetwork(nn.Module):
|
434 |
-
def __init__(self, d_feature, d_in, d_out, d_hidden,
|
435 |
-
n_layers, weight_norm=True, multires_view=0, squeeze_out=True, use_view_dir=True):
|
436 |
-
super().__init__()
|
437 |
-
|
438 |
-
self.squeeze_out = squeeze_out
|
439 |
-
self.rgb_act=F.sigmoid
|
440 |
-
self.use_view_dir=use_view_dir
|
441 |
-
|
442 |
-
dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out]
|
443 |
-
|
444 |
-
self.embedview_fn = None
|
445 |
-
if multires_view > 0:
|
446 |
-
embedview_fn, input_ch = get_embedder(multires_view)
|
447 |
-
self.embedview_fn = embedview_fn
|
448 |
-
dims[0] += (input_ch - 3)
|
449 |
-
|
450 |
-
self.num_layers = len(dims)
|
451 |
-
|
452 |
-
for l in range(0, self.num_layers - 1):
|
453 |
-
out_dim = dims[l + 1]
|
454 |
-
lin = nn.Linear(dims[l], out_dim)
|
455 |
-
|
456 |
-
if weight_norm:
|
457 |
-
lin = nn.utils.weight_norm(lin)
|
458 |
-
|
459 |
-
setattr(self, "lin" + str(l), lin)
|
460 |
-
|
461 |
-
self.relu = nn.ReLU()
|
462 |
-
|
463 |
-
def forward(self, points, normals, view_dirs, feature_vectors):
|
464 |
-
if self.use_view_dir:
|
465 |
-
view_dirs = F.normalize(view_dirs, dim=-1)
|
466 |
-
normals = F.normalize(normals, dim=-1)
|
467 |
-
reflective = torch.sum(view_dirs*normals, -1, keepdim=True) * normals * 2 - view_dirs
|
468 |
-
if self.embedview_fn is not None: reflective = self.embedview_fn(reflective)
|
469 |
-
rendering_input = torch.cat([points, reflective, normals, feature_vectors], dim=-1)
|
470 |
-
else:
|
471 |
-
rendering_input = torch.cat([points, normals, feature_vectors], dim=-1)
|
472 |
-
|
473 |
-
x = rendering_input
|
474 |
-
|
475 |
-
for l in range(0, self.num_layers - 1):
|
476 |
-
lin = getattr(self, "lin" + str(l))
|
477 |
-
|
478 |
-
x = lin(x)
|
479 |
-
|
480 |
-
if l < self.num_layers - 2:
|
481 |
-
x = self.relu(x)
|
482 |
-
|
483 |
-
if self.squeeze_out:
|
484 |
-
x = self.rgb_act(x)
|
485 |
-
return x
|
486 |
-
|
487 |
-
|
488 |
-
class SingleVarianceNetwork(nn.Module):
|
489 |
-
def __init__(self, init_val, activation='exp'):
|
490 |
-
super(SingleVarianceNetwork, self).__init__()
|
491 |
-
self.act = activation
|
492 |
-
self.register_parameter('variance', nn.Parameter(torch.tensor(init_val)))
|
493 |
-
|
494 |
-
def forward(self, x):
|
495 |
-
device = x.device
|
496 |
-
if self.act=='exp':
|
497 |
-
return torch.ones([*x.shape[:-1], 1], dtype=torch.float32, device=device) * torch.exp(self.variance * 10.0)
|
498 |
-
else:
|
499 |
-
raise NotImplementedError
|
500 |
-
|
501 |
-
def warp(self, x, inv_s):
|
502 |
-
device = x.device
|
503 |
-
return torch.ones([*x.shape[:-1], 1], dtype=torch.float32, device=device) * inv_s
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
renderer/ngp_renderer.py
DELETED
@@ -1,721 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
import trimesh
|
3 |
-
import numpy as np
|
4 |
-
|
5 |
-
import torch
|
6 |
-
import torch.nn as nn
|
7 |
-
import torch.nn.functional as F
|
8 |
-
from packaging import version as pver
|
9 |
-
|
10 |
-
import tinycudann as tcnn
|
11 |
-
from torch.autograd import Function
|
12 |
-
|
13 |
-
from torch.cuda.amp import custom_bwd, custom_fwd
|
14 |
-
|
15 |
-
import raymarching
|
16 |
-
|
17 |
-
def custom_meshgrid(*args):
|
18 |
-
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
|
19 |
-
if pver.parse(torch.__version__) < pver.parse('1.10'):
|
20 |
-
return torch.meshgrid(*args)
|
21 |
-
else:
|
22 |
-
return torch.meshgrid(*args, indexing='ij')
|
23 |
-
|
24 |
-
def sample_pdf(bins, weights, n_samples, det=False):
|
25 |
-
# This implementation is from NeRF
|
26 |
-
# bins: [B, T], old_z_vals
|
27 |
-
# weights: [B, T - 1], bin weights.
|
28 |
-
# return: [B, n_samples], new_z_vals
|
29 |
-
|
30 |
-
# Get pdf
|
31 |
-
weights = weights + 1e-5 # prevent nans
|
32 |
-
pdf = weights / torch.sum(weights, -1, keepdim=True)
|
33 |
-
cdf = torch.cumsum(pdf, -1)
|
34 |
-
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
|
35 |
-
# Take uniform samples
|
36 |
-
if det:
|
37 |
-
u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(weights.device)
|
38 |
-
u = u.expand(list(cdf.shape[:-1]) + [n_samples])
|
39 |
-
else:
|
40 |
-
u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)
|
41 |
-
|
42 |
-
# Invert CDF
|
43 |
-
u = u.contiguous()
|
44 |
-
inds = torch.searchsorted(cdf, u, right=True)
|
45 |
-
below = torch.max(torch.zeros_like(inds - 1), inds - 1)
|
46 |
-
above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
|
47 |
-
inds_g = torch.stack([below, above], -1) # (B, n_samples, 2)
|
48 |
-
|
49 |
-
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
|
50 |
-
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
|
51 |
-
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
|
52 |
-
|
53 |
-
denom = (cdf_g[..., 1] - cdf_g[..., 0])
|
54 |
-
denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
|
55 |
-
t = (u - cdf_g[..., 0]) / denom
|
56 |
-
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
|
57 |
-
|
58 |
-
return samples
|
59 |
-
|
60 |
-
|
61 |
-
def plot_pointcloud(pc, color=None):
|
62 |
-
# pc: [N, 3]
|
63 |
-
# color: [N, 3/4]
|
64 |
-
print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0))
|
65 |
-
pc = trimesh.PointCloud(pc, color)
|
66 |
-
# axis
|
67 |
-
axes = trimesh.creation.axis(axis_length=4)
|
68 |
-
# sphere
|
69 |
-
sphere = trimesh.creation.icosphere(radius=1)
|
70 |
-
trimesh.Scene([pc, axes, sphere]).show()
|
71 |
-
|
72 |
-
|
73 |
-
class NGPRenderer(nn.Module):
|
74 |
-
def __init__(self,
|
75 |
-
bound=1,
|
76 |
-
cuda_ray=True,
|
77 |
-
density_scale=1, # scale up deltas (or sigmas), to make the density grid more sharp. larger value than 1 usually improves performance.
|
78 |
-
min_near=0.2,
|
79 |
-
density_thresh=0.01,
|
80 |
-
bg_radius=-1,
|
81 |
-
):
|
82 |
-
super().__init__()
|
83 |
-
|
84 |
-
self.bound = bound
|
85 |
-
self.cascade = 1
|
86 |
-
self.grid_size = 128
|
87 |
-
self.density_scale = density_scale
|
88 |
-
self.min_near = min_near
|
89 |
-
self.density_thresh = density_thresh
|
90 |
-
self.bg_radius = bg_radius # radius of the background sphere.
|
91 |
-
|
92 |
-
# prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
|
93 |
-
# NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.
|
94 |
-
aabb_train = torch.FloatTensor([-bound, -bound, -bound, bound, bound, bound])
|
95 |
-
aabb_infer = aabb_train.clone()
|
96 |
-
self.register_buffer('aabb_train', aabb_train)
|
97 |
-
self.register_buffer('aabb_infer', aabb_infer)
|
98 |
-
|
99 |
-
# extra state for cuda raymarching
|
100 |
-
self.cuda_ray = cuda_ray
|
101 |
-
if cuda_ray:
|
102 |
-
# density grid
|
103 |
-
density_grid = torch.zeros([self.cascade, self.grid_size ** 3]) # [CAS, H * H * H]
|
104 |
-
density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8]
|
105 |
-
self.register_buffer('density_grid', density_grid)
|
106 |
-
self.register_buffer('density_bitfield', density_bitfield)
|
107 |
-
self.mean_density = 0
|
108 |
-
self.iter_density = 0
|
109 |
-
# step counter
|
110 |
-
step_counter = torch.zeros(16, 2, dtype=torch.int32) # 16 is hardcoded for averaging...
|
111 |
-
self.register_buffer('step_counter', step_counter)
|
112 |
-
self.mean_count = 0
|
113 |
-
self.local_step = 0
|
114 |
-
|
115 |
-
def forward(self, x, d):
|
116 |
-
raise NotImplementedError()
|
117 |
-
|
118 |
-
# separated density and color query (can accelerate non-cuda-ray mode.)
|
119 |
-
def density(self, x):
|
120 |
-
raise NotImplementedError()
|
121 |
-
|
122 |
-
def color(self, x, d, mask=None, **kwargs):
|
123 |
-
raise NotImplementedError()
|
124 |
-
|
125 |
-
def reset_extra_state(self):
|
126 |
-
if not self.cuda_ray:
|
127 |
-
return
|
128 |
-
# density grid
|
129 |
-
self.density_grid.zero_()
|
130 |
-
self.mean_density = 0
|
131 |
-
self.iter_density = 0
|
132 |
-
# step counter
|
133 |
-
self.step_counter.zero_()
|
134 |
-
self.mean_count = 0
|
135 |
-
self.local_step = 0
|
136 |
-
|
137 |
-
def run(self, rays_o, rays_d, num_steps=128, upsample_steps=128, bg_color=None, perturb=False, **kwargs):
|
138 |
-
# rays_o, rays_d: [B, N, 3], assumes B == 1
|
139 |
-
# bg_color: [3] in range [0, 1]
|
140 |
-
# return: image: [B, N, 3], depth: [B, N]
|
141 |
-
|
142 |
-
prefix = rays_o.shape[:-1]
|
143 |
-
rays_o = rays_o.contiguous().view(-1, 3)
|
144 |
-
rays_d = rays_d.contiguous().view(-1, 3)
|
145 |
-
|
146 |
-
N = rays_o.shape[0] # N = B * N, in fact
|
147 |
-
device = rays_o.device
|
148 |
-
|
149 |
-
# choose aabb
|
150 |
-
aabb = self.aabb_train if self.training else self.aabb_infer
|
151 |
-
|
152 |
-
# sample steps
|
153 |
-
nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near)
|
154 |
-
nears.unsqueeze_(-1)
|
155 |
-
fars.unsqueeze_(-1)
|
156 |
-
|
157 |
-
#print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}')
|
158 |
-
|
159 |
-
z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(0) # [1, T]
|
160 |
-
z_vals = z_vals.expand((N, num_steps)) # [N, T]
|
161 |
-
z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars]
|
162 |
-
|
163 |
-
# perturb z_vals
|
164 |
-
sample_dist = (fars - nears) / num_steps
|
165 |
-
if perturb:
|
166 |
-
z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist
|
167 |
-
#z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs.
|
168 |
-
|
169 |
-
# generate xyzs
|
170 |
-
xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [N, 1, 3] * [N, T, 1] -> [N, T, 3]
|
171 |
-
xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip.
|
172 |
-
|
173 |
-
#plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())
|
174 |
-
|
175 |
-
# query SDF and RGB
|
176 |
-
density_outputs = self.density(xyzs.reshape(-1, 3))
|
177 |
-
|
178 |
-
#sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T]
|
179 |
-
for k, v in density_outputs.items():
|
180 |
-
density_outputs[k] = v.view(N, num_steps, -1)
|
181 |
-
|
182 |
-
# upsample z_vals (nerf-like)
|
183 |
-
if upsample_steps > 0:
|
184 |
-
with torch.no_grad():
|
185 |
-
|
186 |
-
deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1]
|
187 |
-
deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
|
188 |
-
|
189 |
-
alphas = 1 - torch.exp(-deltas * self.density_scale * density_outputs['sigma'].squeeze(-1)) # [N, T]
|
190 |
-
alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+1]
|
191 |
-
weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T]
|
192 |
-
|
193 |
-
# sample new z_vals
|
194 |
-
z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1]) # [N, T-1]
|
195 |
-
new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training).detach() # [N, t]
|
196 |
-
|
197 |
-
new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1) # [N, 1, 3] * [N, t, 1] -> [N, t, 3]
|
198 |
-
new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:]) # a manual clip.
|
199 |
-
|
200 |
-
# only forward new points to save computation
|
201 |
-
new_density_outputs = self.density(new_xyzs.reshape(-1, 3))
|
202 |
-
#new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t]
|
203 |
-
for k, v in new_density_outputs.items():
|
204 |
-
new_density_outputs[k] = v.view(N, upsample_steps, -1)
|
205 |
-
|
206 |
-
# re-order
|
207 |
-
z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t]
|
208 |
-
z_vals, z_index = torch.sort(z_vals, dim=1)
|
209 |
-
|
210 |
-
xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3]
|
211 |
-
xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs))
|
212 |
-
|
213 |
-
for k in density_outputs:
|
214 |
-
tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1)
|
215 |
-
density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output))
|
216 |
-
|
217 |
-
deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1]
|
218 |
-
deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
|
219 |
-
alphas = 1 - torch.exp(-deltas * self.density_scale * density_outputs['sigma'].squeeze(-1)) # [N, T+t]
|
220 |
-
alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+t+1]
|
221 |
-
weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T+t]
|
222 |
-
|
223 |
-
dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)
|
224 |
-
for k, v in density_outputs.items():
|
225 |
-
density_outputs[k] = v.view(-1, v.shape[-1])
|
226 |
-
|
227 |
-
mask = weights > 1e-4 # hard coded
|
228 |
-
rgbs = self.color(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), mask=mask.reshape(-1), **density_outputs)
|
229 |
-
rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3]
|
230 |
-
|
231 |
-
#print(xyzs.shape, 'valid_rgb:', mask.sum().item())
|
232 |
-
|
233 |
-
# calculate weight_sum (mask)
|
234 |
-
weights_sum = weights.sum(dim=-1) # [N]
|
235 |
-
|
236 |
-
# calculate depth
|
237 |
-
ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1)
|
238 |
-
depth = torch.sum(weights * ori_z_vals, dim=-1)
|
239 |
-
|
240 |
-
# calculate color
|
241 |
-
image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [N, 3], in [0, 1]
|
242 |
-
|
243 |
-
# mix background color
|
244 |
-
if self.bg_radius > 0:
|
245 |
-
# use the bg model to calculate bg_color
|
246 |
-
sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1]
|
247 |
-
bg_color = self.background(sph, rays_d.reshape(-1, 3)) # [N, 3]
|
248 |
-
elif bg_color is None:
|
249 |
-
bg_color = 1
|
250 |
-
|
251 |
-
image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
|
252 |
-
|
253 |
-
image = image.view(*prefix, 3)
|
254 |
-
depth = depth.view(*prefix)
|
255 |
-
|
256 |
-
# tmp: reg loss in mip-nerf 360
|
257 |
-
# z_vals_shifted = torch.cat([z_vals[..., 1:], sample_dist * torch.ones_like(z_vals[..., :1])], dim=-1)
|
258 |
-
# mid_zs = (z_vals + z_vals_shifted) / 2 # [N, T]
|
259 |
-
# loss_dist = (torch.abs(mid_zs.unsqueeze(1) - mid_zs.unsqueeze(2)) * (weights.unsqueeze(1) * weights.unsqueeze(2))).sum() + 1/3 * ((z_vals_shifted - z_vals_shifted) * (weights ** 2)).sum()
|
260 |
-
|
261 |
-
return {
|
262 |
-
'depth': depth,
|
263 |
-
'image': image,
|
264 |
-
'weights_sum': weights_sum,
|
265 |
-
}
|
266 |
-
|
267 |
-
|
268 |
-
def run_cuda(self, rays_o, rays_d, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
|
269 |
-
# rays_o, rays_d: [B, N, 3], assumes B == 1
|
270 |
-
# return: image: [B, N, 3], depth: [B, N]
|
271 |
-
|
272 |
-
prefix = rays_o.shape[:-1]
|
273 |
-
rays_o = rays_o.contiguous().view(-1, 3)
|
274 |
-
rays_d = rays_d.contiguous().view(-1, 3)
|
275 |
-
|
276 |
-
N = rays_o.shape[0] # N = B * N, in fact
|
277 |
-
device = rays_o.device
|
278 |
-
|
279 |
-
# pre-calculate near far
|
280 |
-
nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer, self.min_near)
|
281 |
-
|
282 |
-
# mix background color
|
283 |
-
if self.bg_radius > 0:
|
284 |
-
# use the bg model to calculate bg_color
|
285 |
-
sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1]
|
286 |
-
bg_color = self.background(sph, rays_d) # [N, 3]
|
287 |
-
elif bg_color is None:
|
288 |
-
bg_color = 1
|
289 |
-
|
290 |
-
results = {}
|
291 |
-
|
292 |
-
if self.training:
|
293 |
-
# setup counter
|
294 |
-
counter = self.step_counter[self.local_step % 16]
|
295 |
-
counter.zero_() # set to 0
|
296 |
-
self.local_step += 1
|
297 |
-
|
298 |
-
xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps)
|
299 |
-
|
300 |
-
#plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())
|
301 |
-
|
302 |
-
sigmas, rgbs = self(xyzs, dirs)
|
303 |
-
sigmas = self.density_scale * sigmas
|
304 |
-
|
305 |
-
weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, deltas, rays, T_thresh)
|
306 |
-
image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
|
307 |
-
depth = torch.clamp(depth - nears, min=0) / (fars - nears)
|
308 |
-
image = image.view(*prefix, 3)
|
309 |
-
depth = depth.view(*prefix)
|
310 |
-
|
311 |
-
else:
|
312 |
-
|
313 |
-
# allocate outputs
|
314 |
-
# if use autocast, must init as half so it won't be autocasted and lose reference.
|
315 |
-
#dtype = torch.half if torch.is_autocast_enabled() else torch.float32
|
316 |
-
# output should always be float32! only network inference uses half.
|
317 |
-
dtype = torch.float32
|
318 |
-
|
319 |
-
weights_sum = torch.zeros(N, dtype=dtype, device=device)
|
320 |
-
depth = torch.zeros(N, dtype=dtype, device=device)
|
321 |
-
image = torch.zeros(N, 3, dtype=dtype, device=device)
|
322 |
-
|
323 |
-
n_alive = N
|
324 |
-
rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N]
|
325 |
-
rays_t = nears.clone() # [N]
|
326 |
-
|
327 |
-
step = 0
|
328 |
-
|
329 |
-
while step < max_steps:
|
330 |
-
|
331 |
-
# count alive rays
|
332 |
-
n_alive = rays_alive.shape[0]
|
333 |
-
|
334 |
-
# exit loop
|
335 |
-
if n_alive <= 0:
|
336 |
-
break
|
337 |
-
|
338 |
-
# decide compact_steps
|
339 |
-
n_step = max(min(N // n_alive, 8), 1)
|
340 |
-
|
341 |
-
xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps)
|
342 |
-
|
343 |
-
sigmas, rgbs = self(xyzs, dirs)
|
344 |
-
# density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb.
|
345 |
-
# sigmas = density_outputs['sigma']
|
346 |
-
# rgbs = self.color(xyzs, dirs, **density_outputs)
|
347 |
-
sigmas = self.density_scale * sigmas
|
348 |
-
|
349 |
-
raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh)
|
350 |
-
|
351 |
-
rays_alive = rays_alive[rays_alive >= 0]
|
352 |
-
|
353 |
-
#print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')
|
354 |
-
|
355 |
-
step += n_step
|
356 |
-
|
357 |
-
image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
|
358 |
-
depth = torch.clamp(depth - nears, min=0) / (fars - nears)
|
359 |
-
image = image.view(*prefix, 3)
|
360 |
-
depth = depth.view(*prefix)
|
361 |
-
|
362 |
-
results['weights_sum'] = weights_sum
|
363 |
-
results['depth'] = depth
|
364 |
-
results['image'] = image
|
365 |
-
|
366 |
-
return results
|
367 |
-
|
368 |
-
@torch.no_grad()
|
369 |
-
def mark_untrained_grid(self, poses, intrinsic, S=64):
|
370 |
-
# poses: [B, 4, 4]
|
371 |
-
# intrinsic: [3, 3]
|
372 |
-
|
373 |
-
if not self.cuda_ray:
|
374 |
-
return
|
375 |
-
|
376 |
-
if isinstance(poses, np.ndarray):
|
377 |
-
poses = torch.from_numpy(poses)
|
378 |
-
|
379 |
-
B = poses.shape[0]
|
380 |
-
|
381 |
-
fx, fy, cx, cy = intrinsic
|
382 |
-
|
383 |
-
X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
384 |
-
Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
385 |
-
Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
386 |
-
|
387 |
-
count = torch.zeros_like(self.density_grid)
|
388 |
-
poses = poses.to(count.device)
|
389 |
-
|
390 |
-
# 5-level loop, forgive me...
|
391 |
-
|
392 |
-
for xs in X:
|
393 |
-
for ys in Y:
|
394 |
-
for zs in Z:
|
395 |
-
|
396 |
-
# construct points
|
397 |
-
xx, yy, zz = custom_meshgrid(xs, ys, zs)
|
398 |
-
coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
|
399 |
-
indices = raymarching.morton3D(coords).long() # [N]
|
400 |
-
world_xyzs = (2 * coords.float() / (self.grid_size - 1) - 1).unsqueeze(0) # [1, N, 3] in [-1, 1]
|
401 |
-
|
402 |
-
# cascading
|
403 |
-
for cas in range(self.cascade):
|
404 |
-
bound = min(2 ** cas, self.bound)
|
405 |
-
half_grid_size = bound / self.grid_size
|
406 |
-
# scale to current cascade's resolution
|
407 |
-
cas_world_xyzs = world_xyzs * (bound - half_grid_size)
|
408 |
-
|
409 |
-
# split batch to avoid OOM
|
410 |
-
head = 0
|
411 |
-
while head < B:
|
412 |
-
tail = min(head + S, B)
|
413 |
-
|
414 |
-
# world2cam transform (poses is c2w, so we need to transpose it. Another transpose is needed for batched matmul, so the final form is without transpose.)
|
415 |
-
cam_xyzs = cas_world_xyzs - poses[head:tail, :3, 3].unsqueeze(1)
|
416 |
-
cam_xyzs = cam_xyzs @ poses[head:tail, :3, :3] # [S, N, 3]
|
417 |
-
|
418 |
-
# query if point is covered by any camera
|
419 |
-
mask_z = cam_xyzs[:, :, 2] > 0 # [S, N]
|
420 |
-
mask_x = torch.abs(cam_xyzs[:, :, 0]) < cx / fx * cam_xyzs[:, :, 2] + half_grid_size * 2
|
421 |
-
mask_y = torch.abs(cam_xyzs[:, :, 1]) < cy / fy * cam_xyzs[:, :, 2] + half_grid_size * 2
|
422 |
-
mask = (mask_z & mask_x & mask_y).sum(0).reshape(-1) # [N]
|
423 |
-
|
424 |
-
# update count
|
425 |
-
count[cas, indices] += mask
|
426 |
-
head += S
|
427 |
-
|
428 |
-
# mark untrained grid as -1
|
429 |
-
self.density_grid[count == 0] = -1
|
430 |
-
|
431 |
-
print(f'[mark untrained grid] {(count == 0).sum()} from {self.grid_size ** 3 * self.cascade}')
|
432 |
-
|
433 |
-
@torch.no_grad()
|
434 |
-
def update_extra_state(self, decay=0.95, S=128):
|
435 |
-
# call before each epoch to update extra states.
|
436 |
-
|
437 |
-
if not self.cuda_ray:
|
438 |
-
return
|
439 |
-
|
440 |
-
### update density grid
|
441 |
-
tmp_grid = - torch.ones_like(self.density_grid)
|
442 |
-
|
443 |
-
# full update.
|
444 |
-
if self.iter_density < 16:
|
445 |
-
#if True:
|
446 |
-
X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
447 |
-
Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
448 |
-
Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
449 |
-
|
450 |
-
for xs in X:
|
451 |
-
for ys in Y:
|
452 |
-
for zs in Z:
|
453 |
-
|
454 |
-
# construct points
|
455 |
-
xx, yy, zz = custom_meshgrid(xs, ys, zs)
|
456 |
-
coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
|
457 |
-
indices = raymarching.morton3D(coords).long() # [N]
|
458 |
-
xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
|
459 |
-
|
460 |
-
# cascading
|
461 |
-
for cas in range(self.cascade):
|
462 |
-
bound = min(2 ** cas, self.bound)
|
463 |
-
half_grid_size = bound / self.grid_size
|
464 |
-
# scale to current cascade's resolution
|
465 |
-
cas_xyzs = xyzs * (bound - half_grid_size)
|
466 |
-
# add noise in [-hgs, hgs]
|
467 |
-
cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
|
468 |
-
# query density
|
469 |
-
sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
|
470 |
-
sigmas *= self.density_scale
|
471 |
-
# assign
|
472 |
-
tmp_grid[cas, indices] = sigmas
|
473 |
-
|
474 |
-
# partial update (half the computation)
|
475 |
-
# TODO: why no need of maxpool ?
|
476 |
-
else:
|
477 |
-
N = self.grid_size ** 3 // 4 # H * H * H / 4
|
478 |
-
for cas in range(self.cascade):
|
479 |
-
# random sample some positions
|
480 |
-
coords = torch.randint(0, self.grid_size, (N, 3), device=self.density_bitfield.device) # [N, 3], in [0, 128)
|
481 |
-
indices = raymarching.morton3D(coords).long() # [N]
|
482 |
-
# random sample occupied positions
|
483 |
-
occ_indices = torch.nonzero(self.density_grid[cas] > 0).squeeze(-1) # [Nz]
|
484 |
-
rand_mask = torch.randint(0, occ_indices.shape[0], [N], dtype=torch.long, device=self.density_bitfield.device)
|
485 |
-
occ_indices = occ_indices[rand_mask] # [Nz] --> [N], allow for duplication
|
486 |
-
occ_coords = raymarching.morton3D_invert(occ_indices) # [N, 3]
|
487 |
-
# concat
|
488 |
-
indices = torch.cat([indices, occ_indices], dim=0)
|
489 |
-
coords = torch.cat([coords, occ_coords], dim=0)
|
490 |
-
# same below
|
491 |
-
xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
|
492 |
-
bound = min(2 ** cas, self.bound)
|
493 |
-
half_grid_size = bound / self.grid_size
|
494 |
-
# scale to current cascade's resolution
|
495 |
-
cas_xyzs = xyzs * (bound - half_grid_size)
|
496 |
-
# add noise in [-hgs, hgs]
|
497 |
-
cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
|
498 |
-
# query density
|
499 |
-
sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
|
500 |
-
sigmas *= self.density_scale
|
501 |
-
# assign
|
502 |
-
tmp_grid[cas, indices] = sigmas
|
503 |
-
|
504 |
-
## max-pool on tmp_grid for less aggressive culling [No significant improvement...]
|
505 |
-
# invalid_mask = tmp_grid < 0
|
506 |
-
# tmp_grid = F.max_pool3d(tmp_grid.view(self.cascade, 1, self.grid_size, self.grid_size, self.grid_size), kernel_size=3, stride=1, padding=1).view(self.cascade, -1)
|
507 |
-
# tmp_grid[invalid_mask] = -1
|
508 |
-
|
509 |
-
# ema update
|
510 |
-
valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0)
|
511 |
-
self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
|
512 |
-
self.mean_density = torch.mean(self.density_grid.clamp(min=0)).item() # -1 regions are viewed as 0 density.
|
513 |
-
#self.mean_density = torch.mean(self.density_grid[self.density_grid > 0]).item() # do not count -1 regions
|
514 |
-
self.iter_density += 1
|
515 |
-
|
516 |
-
# convert to bitfield
|
517 |
-
density_thresh = min(self.mean_density, self.density_thresh)
|
518 |
-
self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield)
|
519 |
-
|
520 |
-
### update step counter
|
521 |
-
total_step = min(16, self.local_step)
|
522 |
-
if total_step > 0:
|
523 |
-
self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step)
|
524 |
-
self.local_step = 0
|
525 |
-
|
526 |
-
#print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > 0.01).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}')
|
527 |
-
|
528 |
-
|
529 |
-
def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs):
|
530 |
-
# rays_o, rays_d: [B, N, 3], assumes B == 1
|
531 |
-
# return: pred_rgb: [B, N, 3]
|
532 |
-
|
533 |
-
if self.cuda_ray:
|
534 |
-
_run = self.run_cuda
|
535 |
-
else:
|
536 |
-
_run = self.run
|
537 |
-
|
538 |
-
results = _run(rays_o, rays_d, **kwargs)
|
539 |
-
return results
|
540 |
-
|
541 |
-
|
542 |
-
|
543 |
-
class _trunc_exp(Function):
|
544 |
-
@staticmethod
|
545 |
-
@custom_fwd(cast_inputs=torch.float32) # cast to float32
|
546 |
-
def forward(ctx, x):
|
547 |
-
ctx.save_for_backward(x)
|
548 |
-
return torch.exp(x)
|
549 |
-
|
550 |
-
@staticmethod
|
551 |
-
@custom_bwd
|
552 |
-
def backward(ctx, g):
|
553 |
-
x = ctx.saved_tensors[0]
|
554 |
-
return g * torch.exp(x.clamp(-15, 15))
|
555 |
-
|
556 |
-
trunc_exp = _trunc_exp.apply
|
557 |
-
|
558 |
-
class NGPNetwork(NGPRenderer):
|
559 |
-
def __init__(self,
|
560 |
-
num_layers=2,
|
561 |
-
hidden_dim=64,
|
562 |
-
geo_feat_dim=15,
|
563 |
-
num_layers_color=3,
|
564 |
-
hidden_dim_color=64,
|
565 |
-
bound=0.5,
|
566 |
-
max_resolution=128,
|
567 |
-
base_resolution=16,
|
568 |
-
n_levels=16,
|
569 |
-
**kwargs
|
570 |
-
):
|
571 |
-
super().__init__(bound, **kwargs)
|
572 |
-
|
573 |
-
# sigma network
|
574 |
-
self.num_layers = num_layers
|
575 |
-
self.hidden_dim = hidden_dim
|
576 |
-
self.geo_feat_dim = geo_feat_dim
|
577 |
-
self.bound = bound
|
578 |
-
|
579 |
-
log2_hashmap_size = 19
|
580 |
-
n_features_per_level = 2
|
581 |
-
|
582 |
-
|
583 |
-
per_level_scale = np.exp2(np.log2(max_resolution / base_resolution) / (n_levels - 1))
|
584 |
-
|
585 |
-
self.encoder = tcnn.Encoding(
|
586 |
-
n_input_dims=3,
|
587 |
-
encoding_config={
|
588 |
-
"otype": "HashGrid",
|
589 |
-
"n_levels": n_levels,
|
590 |
-
"n_features_per_level": n_features_per_level,
|
591 |
-
"log2_hashmap_size": log2_hashmap_size,
|
592 |
-
"base_resolution": base_resolution,
|
593 |
-
"per_level_scale": per_level_scale,
|
594 |
-
},
|
595 |
-
)
|
596 |
-
|
597 |
-
self.sigma_net = tcnn.Network(
|
598 |
-
n_input_dims = n_levels * 2,
|
599 |
-
n_output_dims=1 + self.geo_feat_dim,
|
600 |
-
network_config={
|
601 |
-
"otype": "FullyFusedMLP",
|
602 |
-
"activation": "ReLU",
|
603 |
-
"output_activation": "None",
|
604 |
-
"n_neurons": hidden_dim,
|
605 |
-
"n_hidden_layers": num_layers - 1,
|
606 |
-
},
|
607 |
-
)
|
608 |
-
|
609 |
-
# color network
|
610 |
-
self.num_layers_color = num_layers_color
|
611 |
-
self.hidden_dim_color = hidden_dim_color
|
612 |
-
|
613 |
-
self.encoder_dir = tcnn.Encoding(
|
614 |
-
n_input_dims=3,
|
615 |
-
encoding_config={
|
616 |
-
"otype": "SphericalHarmonics",
|
617 |
-
"degree": 4,
|
618 |
-
},
|
619 |
-
)
|
620 |
-
|
621 |
-
self.in_dim_color = self.encoder_dir.n_output_dims + self.geo_feat_dim
|
622 |
-
|
623 |
-
self.color_net = tcnn.Network(
|
624 |
-
n_input_dims = self.in_dim_color,
|
625 |
-
n_output_dims=3,
|
626 |
-
network_config={
|
627 |
-
"otype": "FullyFusedMLP",
|
628 |
-
"activation": "ReLU",
|
629 |
-
"output_activation": "None",
|
630 |
-
"n_neurons": hidden_dim_color,
|
631 |
-
"n_hidden_layers": num_layers_color - 1,
|
632 |
-
},
|
633 |
-
)
|
634 |
-
self.density_scale, self.density_std = 10.0, 0.25
|
635 |
-
|
636 |
-
def forward(self, x, d):
|
637 |
-
# x: [N, 3], in [-bound, bound]
|
638 |
-
# d: [N, 3], nomalized in [-1, 1]
|
639 |
-
|
640 |
-
|
641 |
-
# sigma
|
642 |
-
x_raw = x
|
643 |
-
x = (x + self.bound) / (2 * self.bound) # to [0, 1]
|
644 |
-
x = self.encoder(x)
|
645 |
-
h = self.sigma_net(x)
|
646 |
-
|
647 |
-
# sigma = F.relu(h[..., 0])
|
648 |
-
density = h[..., 0]
|
649 |
-
# add density bias
|
650 |
-
dist = torch.norm(x_raw, dim=-1)
|
651 |
-
density_bias = (1 - dist / self.density_std) * self.density_scale
|
652 |
-
density = density_bias + density
|
653 |
-
sigma = F.softplus(density)
|
654 |
-
geo_feat = h[..., 1:]
|
655 |
-
|
656 |
-
# color
|
657 |
-
d = (d + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1]
|
658 |
-
d = self.encoder_dir(d)
|
659 |
-
|
660 |
-
# p = torch.zeros_like(geo_feat[..., :1]) # manual input padding
|
661 |
-
h = torch.cat([d, geo_feat], dim=-1)
|
662 |
-
h = self.color_net(h)
|
663 |
-
|
664 |
-
# sigmoid activation for rgb
|
665 |
-
color = torch.sigmoid(h)
|
666 |
-
|
667 |
-
return sigma, color
|
668 |
-
|
669 |
-
def density(self, x):
|
670 |
-
# x: [N, 3], in [-bound, bound]
|
671 |
-
x_raw = x
|
672 |
-
x = (x + self.bound) / (2 * self.bound) # to [0, 1]
|
673 |
-
x = self.encoder(x)
|
674 |
-
h = self.sigma_net(x)
|
675 |
-
|
676 |
-
# sigma = F.relu(h[..., 0])
|
677 |
-
density = h[..., 0]
|
678 |
-
# add density bias
|
679 |
-
dist = torch.norm(x_raw, dim=-1)
|
680 |
-
density_bias = (1 - dist / self.density_std) * self.density_scale
|
681 |
-
density = density_bias + density
|
682 |
-
sigma = F.softplus(density)
|
683 |
-
geo_feat = h[..., 1:]
|
684 |
-
|
685 |
-
return {
|
686 |
-
'sigma': sigma,
|
687 |
-
'geo_feat': geo_feat,
|
688 |
-
}
|
689 |
-
|
690 |
-
# allow masked inference
|
691 |
-
def color(self, x, d, mask=None, geo_feat=None, **kwargs):
|
692 |
-
# x: [N, 3] in [-bound, bound]
|
693 |
-
# mask: [N,], bool, indicates where we actually needs to compute rgb.
|
694 |
-
|
695 |
-
x = (x + self.bound) / (2 * self.bound) # to [0, 1]
|
696 |
-
|
697 |
-
if mask is not None:
|
698 |
-
rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3]
|
699 |
-
# in case of empty mask
|
700 |
-
if not mask.any():
|
701 |
-
return rgbs
|
702 |
-
x = x[mask]
|
703 |
-
d = d[mask]
|
704 |
-
geo_feat = geo_feat[mask]
|
705 |
-
|
706 |
-
# color
|
707 |
-
d = (d + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1]
|
708 |
-
d = self.encoder_dir(d)
|
709 |
-
|
710 |
-
h = torch.cat([d, geo_feat], dim=-1)
|
711 |
-
h = self.color_net(h)
|
712 |
-
|
713 |
-
# sigmoid activation for rgb
|
714 |
-
h = torch.sigmoid(h)
|
715 |
-
|
716 |
-
if mask is not None:
|
717 |
-
rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32
|
718 |
-
else:
|
719 |
-
rgbs = h
|
720 |
-
|
721 |
-
return rgbs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
renderer/renderer.py
DELETED
@@ -1,604 +0,0 @@
|
|
1 |
-
import abc
|
2 |
-
import os
|
3 |
-
from pathlib import Path
|
4 |
-
|
5 |
-
import cv2
|
6 |
-
import numpy as np
|
7 |
-
import pytorch_lightning as pl
|
8 |
-
import torch
|
9 |
-
import torch.nn as nn
|
10 |
-
import torch.nn.functional as F
|
11 |
-
from omegaconf import OmegaConf
|
12 |
-
|
13 |
-
from skimage.io import imread, imsave
|
14 |
-
from PIL import Image
|
15 |
-
from torch.optim.lr_scheduler import LambdaLR
|
16 |
-
|
17 |
-
from ldm.base_utils import read_pickle, concat_images_list
|
18 |
-
from renderer.neus_networks import SDFNetwork, RenderingNetwork, SingleVarianceNetwork, SDFHashGridNetwork, RenderingFFNetwork
|
19 |
-
from renderer.ngp_renderer import NGPNetwork
|
20 |
-
from ldm.util import instantiate_from_config
|
21 |
-
|
22 |
-
DEFAULT_RADIUS = np.sqrt(3)/2
|
23 |
-
DEFAULT_SIDE_LENGTH = 0.6
|
24 |
-
|
25 |
-
def sample_pdf(bins, weights, n_samples, det=True):
|
26 |
-
device = bins.device
|
27 |
-
dtype = bins.dtype
|
28 |
-
# This implementation is from NeRF
|
29 |
-
# Get pdf
|
30 |
-
weights = weights + 1e-5 # prevent nans
|
31 |
-
pdf = weights / torch.sum(weights, -1, keepdim=True)
|
32 |
-
cdf = torch.cumsum(pdf, -1)
|
33 |
-
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
|
34 |
-
# Take uniform samples
|
35 |
-
if det:
|
36 |
-
u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples, dtype=dtype, device=device)
|
37 |
-
u = u.expand(list(cdf.shape[:-1]) + [n_samples])
|
38 |
-
else:
|
39 |
-
u = torch.rand(list(cdf.shape[:-1]) + [n_samples], dtype=dtype, device=device)
|
40 |
-
|
41 |
-
# Invert CDF
|
42 |
-
u = u.contiguous()
|
43 |
-
inds = torch.searchsorted(cdf, u, right=True)
|
44 |
-
below = torch.max(torch.zeros_like(inds - 1), inds - 1)
|
45 |
-
above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
|
46 |
-
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
|
47 |
-
|
48 |
-
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
|
49 |
-
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
|
50 |
-
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
|
51 |
-
|
52 |
-
denom = (cdf_g[..., 1] - cdf_g[..., 0])
|
53 |
-
denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
|
54 |
-
t = (u - cdf_g[..., 0]) / denom
|
55 |
-
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
|
56 |
-
|
57 |
-
return samples
|
58 |
-
|
59 |
-
def near_far_from_sphere(rays_o, rays_d, radius=DEFAULT_RADIUS):
|
60 |
-
a = torch.sum(rays_d ** 2, dim=-1, keepdim=True)
|
61 |
-
b = torch.sum(rays_o * rays_d, dim=-1, keepdim=True)
|
62 |
-
mid = -b / a
|
63 |
-
near = mid - radius
|
64 |
-
far = mid + radius
|
65 |
-
return near, far
|
66 |
-
|
67 |
-
class BackgroundRemoval:
|
68 |
-
def __init__(self, device='cuda'):
|
69 |
-
from carvekit.api.high import HiInterface
|
70 |
-
self.interface = HiInterface(
|
71 |
-
object_type="object", # Can be "object" or "hairs-like".
|
72 |
-
batch_size_seg=5,
|
73 |
-
batch_size_matting=1,
|
74 |
-
device=device,
|
75 |
-
seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
|
76 |
-
matting_mask_size=2048,
|
77 |
-
trimap_prob_threshold=231,
|
78 |
-
trimap_dilation=30,
|
79 |
-
trimap_erosion_iters=5,
|
80 |
-
fp16=True,
|
81 |
-
)
|
82 |
-
|
83 |
-
@torch.no_grad()
|
84 |
-
def __call__(self, image):
|
85 |
-
# image: [H, W, 3] array in [0, 255].
|
86 |
-
image = Image.fromarray(image)
|
87 |
-
image = self.interface([image])[0]
|
88 |
-
image = np.array(image)
|
89 |
-
return image
|
90 |
-
|
91 |
-
|
92 |
-
class BaseRenderer(nn.Module):
|
93 |
-
def __init__(self, train_batch_num, test_batch_num):
|
94 |
-
super().__init__()
|
95 |
-
self.train_batch_num = train_batch_num
|
96 |
-
self.test_batch_num = test_batch_num
|
97 |
-
|
98 |
-
@abc.abstractmethod
|
99 |
-
def render_impl(self, ray_batch, is_train, step):
|
100 |
-
pass
|
101 |
-
|
102 |
-
@abc.abstractmethod
|
103 |
-
def render_with_loss(self, ray_batch, is_train, step):
|
104 |
-
pass
|
105 |
-
|
106 |
-
def render(self, ray_batch, is_train, step):
|
107 |
-
batch_num = self.train_batch_num if is_train else self.test_batch_num
|
108 |
-
ray_num = ray_batch['rays_o'].shape[0]
|
109 |
-
outputs = {}
|
110 |
-
for ri in range(0, ray_num, batch_num):
|
111 |
-
cur_ray_batch = {}
|
112 |
-
for k, v in ray_batch.items():
|
113 |
-
cur_ray_batch[k] = v[ri:ri + batch_num]
|
114 |
-
cur_outputs = self.render_impl(cur_ray_batch, is_train, step)
|
115 |
-
for k, v in cur_outputs.items():
|
116 |
-
if k not in outputs: outputs[k] = []
|
117 |
-
outputs[k].append(v)
|
118 |
-
|
119 |
-
for k, v in outputs.items():
|
120 |
-
outputs[k] = torch.cat(v, 0)
|
121 |
-
return outputs
|
122 |
-
|
123 |
-
|
124 |
-
class NeuSRenderer(BaseRenderer):
|
125 |
-
def __init__(self, train_batch_num, test_batch_num, lambda_eikonal_loss=0.1, use_mask=True,
|
126 |
-
lambda_rgb_loss=1.0, lambda_mask_loss=0.0, rgb_loss='soft_l1', coarse_sn=64, fine_sn=64):
|
127 |
-
super().__init__(train_batch_num, test_batch_num)
|
128 |
-
self.n_samples = coarse_sn
|
129 |
-
self.n_importance = fine_sn
|
130 |
-
self.up_sample_steps = 4
|
131 |
-
self.anneal_end = 200
|
132 |
-
self.use_mask = use_mask
|
133 |
-
self.lambda_eikonal_loss = lambda_eikonal_loss
|
134 |
-
self.lambda_rgb_loss = lambda_rgb_loss
|
135 |
-
self.lambda_mask_loss = lambda_mask_loss
|
136 |
-
self.rgb_loss = rgb_loss
|
137 |
-
|
138 |
-
self.sdf_network = SDFNetwork(d_out=257, d_in=3, d_hidden=256, n_layers=8, skip_in=[4], multires=6, bias=0.5, scale=1.0, geometric_init=True, weight_norm=True)
|
139 |
-
self.color_network = RenderingNetwork(d_feature=256, d_in=9, d_out=3, d_hidden=256, n_layers=4, weight_norm=True, multires_view=4, squeeze_out=True)
|
140 |
-
self.default_dtype = torch.float32
|
141 |
-
self.deviation_network = SingleVarianceNetwork(0.3)
|
142 |
-
|
143 |
-
@torch.no_grad()
|
144 |
-
def get_vertex_colors(self, vertices):
|
145 |
-
"""
|
146 |
-
@param vertices: n,3
|
147 |
-
@return:
|
148 |
-
"""
|
149 |
-
V = vertices.shape[0]
|
150 |
-
bn = 20480
|
151 |
-
verts_colors = []
|
152 |
-
with torch.no_grad():
|
153 |
-
for vi in range(0, V, bn):
|
154 |
-
verts = torch.from_numpy(vertices[vi:vi+bn].astype(np.float32)).cuda()
|
155 |
-
feats = self.sdf_network(verts)[..., 1:]
|
156 |
-
gradients = self.sdf_network.gradient(verts) # ...,3
|
157 |
-
gradients = F.normalize(gradients, dim=-1)
|
158 |
-
colors = self.color_network(verts, gradients, gradients, feats)
|
159 |
-
colors = torch.clamp(colors,min=0,max=1).cpu().numpy()
|
160 |
-
verts_colors.append(colors)
|
161 |
-
|
162 |
-
verts_colors = (np.concatenate(verts_colors, 0)*255).astype(np.uint8)
|
163 |
-
return verts_colors
|
164 |
-
|
165 |
-
def upsample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s):
|
166 |
-
"""
|
167 |
-
Up sampling give a fixed inv_s
|
168 |
-
"""
|
169 |
-
device = rays_o.device
|
170 |
-
batch_size, n_samples = z_vals.shape
|
171 |
-
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3
|
172 |
-
inner_mask = self.get_inner_mask(pts)
|
173 |
-
# radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
|
174 |
-
inside_sphere = inner_mask[:, :-1] | inner_mask[:, 1:]
|
175 |
-
sdf = sdf.reshape(batch_size, n_samples)
|
176 |
-
prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
|
177 |
-
prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
|
178 |
-
mid_sdf = (prev_sdf + next_sdf) * 0.5
|
179 |
-
cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
|
180 |
-
|
181 |
-
prev_cos_val = torch.cat([torch.zeros([batch_size, 1], dtype=self.default_dtype, device=device), cos_val[:, :-1]], dim=-1)
|
182 |
-
cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
|
183 |
-
cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
|
184 |
-
cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere
|
185 |
-
|
186 |
-
dist = (next_z_vals - prev_z_vals)
|
187 |
-
prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
|
188 |
-
next_esti_sdf = mid_sdf + cos_val * dist * 0.5
|
189 |
-
prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
|
190 |
-
next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
|
191 |
-
alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
|
192 |
-
weights = alpha * torch.cumprod(
|
193 |
-
torch.cat([torch.ones([batch_size, 1], dtype=self.default_dtype, device=device), 1. - alpha + 1e-7], -1), -1)[:, :-1]
|
194 |
-
|
195 |
-
z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
|
196 |
-
return z_samples
|
197 |
-
|
198 |
-
def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False):
|
199 |
-
batch_size, n_samples = z_vals.shape
|
200 |
-
_, n_importance = new_z_vals.shape
|
201 |
-
pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
|
202 |
-
z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
|
203 |
-
z_vals, index = torch.sort(z_vals, dim=-1)
|
204 |
-
|
205 |
-
if not last:
|
206 |
-
device = pts.device
|
207 |
-
new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
|
208 |
-
sdf = torch.cat([sdf, new_sdf], dim=-1)
|
209 |
-
xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1).to(device)
|
210 |
-
index = index.reshape(-1)
|
211 |
-
sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)
|
212 |
-
|
213 |
-
return z_vals, sdf
|
214 |
-
|
215 |
-
def sample_depth(self, rays_o, rays_d, near, far, perturb):
|
216 |
-
n_samples = self.n_samples
|
217 |
-
n_importance = self.n_importance
|
218 |
-
up_sample_steps = self.up_sample_steps
|
219 |
-
device = rays_o.device
|
220 |
-
|
221 |
-
# sample points
|
222 |
-
batch_size = len(rays_o)
|
223 |
-
z_vals = torch.linspace(0.0, 1.0, n_samples, dtype=self.default_dtype, device=device) # sn
|
224 |
-
z_vals = near + (far - near) * z_vals[None, :] # rn,sn
|
225 |
-
|
226 |
-
if perturb > 0:
|
227 |
-
t_rand = (torch.rand([batch_size, 1]).to(device) - 0.5)
|
228 |
-
z_vals = z_vals + t_rand * 2.0 / n_samples
|
229 |
-
|
230 |
-
# Up sample
|
231 |
-
with torch.no_grad():
|
232 |
-
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
|
233 |
-
sdf = self.sdf_network.sdf(pts).reshape(batch_size, n_samples)
|
234 |
-
|
235 |
-
for i in range(up_sample_steps):
|
236 |
-
rn, sn = z_vals.shape
|
237 |
-
inv_s = torch.ones(rn, sn - 1, dtype=self.default_dtype, device=device) * 64 * 2 ** i
|
238 |
-
new_z_vals = self.upsample(rays_o, rays_d, z_vals, sdf, n_importance // up_sample_steps, inv_s)
|
239 |
-
z_vals, sdf = self.cat_z_vals(rays_o, rays_d, z_vals, new_z_vals, sdf, last=(i + 1 == up_sample_steps))
|
240 |
-
|
241 |
-
return z_vals
|
242 |
-
|
243 |
-
def compute_sdf_alpha(self, points, dists, dirs, cos_anneal_ratio, step):
|
244 |
-
# points [...,3] dists [...] dirs[...,3]
|
245 |
-
sdf_nn_output = self.sdf_network(points)
|
246 |
-
sdf = sdf_nn_output[..., 0]
|
247 |
-
feature_vector = sdf_nn_output[..., 1:]
|
248 |
-
|
249 |
-
gradients = self.sdf_network.gradient(points) # ...,3
|
250 |
-
inv_s = self.deviation_network(points).clip(1e-6, 1e6) # ...,1
|
251 |
-
inv_s = inv_s[..., 0]
|
252 |
-
|
253 |
-
true_cos = (dirs * gradients).sum(-1) # [...]
|
254 |
-
iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) +
|
255 |
-
F.relu(-true_cos) * cos_anneal_ratio) # always non-positive
|
256 |
-
|
257 |
-
# Estimate signed distances at section points
|
258 |
-
estimated_next_sdf = sdf + iter_cos * dists * 0.5
|
259 |
-
estimated_prev_sdf = sdf - iter_cos * dists * 0.5
|
260 |
-
|
261 |
-
prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
|
262 |
-
next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
|
263 |
-
|
264 |
-
p = prev_cdf - next_cdf
|
265 |
-
c = prev_cdf
|
266 |
-
|
267 |
-
alpha = ((p + 1e-5) / (c + 1e-5)).clip(0.0, 1.0) # [...]
|
268 |
-
return alpha, gradients, feature_vector, inv_s, sdf
|
269 |
-
|
270 |
-
def get_anneal_val(self, step):
|
271 |
-
if self.anneal_end < 0:
|
272 |
-
return 1.0
|
273 |
-
else:
|
274 |
-
return np.min([1.0, step / self.anneal_end])
|
275 |
-
|
276 |
-
def get_inner_mask(self, points):
|
277 |
-
return torch.sum(torch.abs(points)<=DEFAULT_SIDE_LENGTH,-1)==3
|
278 |
-
|
279 |
-
def render_impl(self, ray_batch, is_train, step):
|
280 |
-
near, far = near_far_from_sphere(ray_batch['rays_o'], ray_batch['rays_d'])
|
281 |
-
rays_o, rays_d = ray_batch['rays_o'], ray_batch['rays_d']
|
282 |
-
z_vals = self.sample_depth(rays_o, rays_d, near, far, is_train)
|
283 |
-
|
284 |
-
batch_size, n_samples = z_vals.shape
|
285 |
-
|
286 |
-
# section length in original space
|
287 |
-
dists = z_vals[..., 1:] - z_vals[..., :-1] # rn,sn-1
|
288 |
-
dists = torch.cat([dists, dists[..., -1:]], -1) # rn,sn
|
289 |
-
mid_z_vals = z_vals + dists * 0.5
|
290 |
-
|
291 |
-
points = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * mid_z_vals.unsqueeze(-1) # rn, sn, 3
|
292 |
-
inner_mask = self.get_inner_mask(points)
|
293 |
-
|
294 |
-
dirs = rays_d.unsqueeze(-2).expand(batch_size, n_samples, 3)
|
295 |
-
dirs = F.normalize(dirs, dim=-1)
|
296 |
-
device = rays_o.device
|
297 |
-
alpha, sampled_color, gradient_error, normal = torch.zeros(batch_size, n_samples, dtype=self.default_dtype, device=device), \
|
298 |
-
torch.zeros(batch_size, n_samples, 3, dtype=self.default_dtype, device=device), \
|
299 |
-
torch.zeros([batch_size, n_samples], dtype=self.default_dtype, device=device), \
|
300 |
-
torch.zeros([batch_size, n_samples, 3], dtype=self.default_dtype, device=device)
|
301 |
-
if torch.sum(inner_mask) > 0:
|
302 |
-
cos_anneal_ratio = self.get_anneal_val(step) if is_train else 1.0
|
303 |
-
alpha[inner_mask], gradients, feature_vector, inv_s, sdf = self.compute_sdf_alpha(points[inner_mask], dists[inner_mask], dirs[inner_mask], cos_anneal_ratio, step)
|
304 |
-
sampled_color[inner_mask] = self.color_network(points[inner_mask], gradients, -dirs[inner_mask], feature_vector)
|
305 |
-
# Eikonal loss
|
306 |
-
gradient_error[inner_mask] = (torch.linalg.norm(gradients, ord=2, dim=-1) - 1.0) ** 2 # rn,sn
|
307 |
-
normal[inner_mask] = F.normalize(gradients, dim=-1)
|
308 |
-
|
309 |
-
weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1], dtype=self.default_dtype, device=device), 1. - alpha + 1e-7], -1), -1)[..., :-1] # rn,sn
|
310 |
-
mask = torch.sum(weights,dim=1).unsqueeze(-1) # rn,1
|
311 |
-
color = (sampled_color * weights[..., None]).sum(dim=1) + (1 - mask) # add white background
|
312 |
-
normal = (normal * weights[..., None]).sum(dim=1)
|
313 |
-
|
314 |
-
outputs = {
|
315 |
-
'rgb': color, # rn,3
|
316 |
-
'gradient_error': gradient_error, # rn,sn
|
317 |
-
'inner_mask': inner_mask, # rn,sn
|
318 |
-
'normal': normal, # rn,3
|
319 |
-
'mask': mask, # rn,1
|
320 |
-
}
|
321 |
-
return outputs
|
322 |
-
|
323 |
-
def render_with_loss(self, ray_batch, is_train, step):
|
324 |
-
render_outputs = self.render(ray_batch, is_train, step)
|
325 |
-
|
326 |
-
rgb_gt = ray_batch['rgb']
|
327 |
-
rgb_pr = render_outputs['rgb']
|
328 |
-
if self.rgb_loss == 'soft_l1':
|
329 |
-
epsilon = 0.001
|
330 |
-
rgb_loss = torch.sqrt(torch.sum((rgb_gt - rgb_pr) ** 2, dim=-1) + epsilon)
|
331 |
-
elif self.rgb_loss =='mse':
|
332 |
-
rgb_loss = F.mse_loss(rgb_pr, rgb_gt, reduction='none')
|
333 |
-
else:
|
334 |
-
raise NotImplementedError
|
335 |
-
rgb_loss = torch.mean(rgb_loss)
|
336 |
-
|
337 |
-
eikonal_loss = torch.sum(render_outputs['gradient_error'] * render_outputs['inner_mask']) / torch.sum(render_outputs['inner_mask'] + 1e-5)
|
338 |
-
loss = rgb_loss * self.lambda_rgb_loss + eikonal_loss * self.lambda_eikonal_loss
|
339 |
-
loss_batch = {
|
340 |
-
'eikonal': eikonal_loss,
|
341 |
-
'rendering': rgb_loss,
|
342 |
-
# 'mask': mask_loss,
|
343 |
-
}
|
344 |
-
if self.lambda_mask_loss>0 and self.use_mask:
|
345 |
-
mask_loss = F.mse_loss(render_outputs['mask'], ray_batch['mask'], reduction='none').mean()
|
346 |
-
loss += mask_loss * self.lambda_mask_loss
|
347 |
-
loss_batch['mask'] = mask_loss
|
348 |
-
return loss, loss_batch
|
349 |
-
|
350 |
-
|
351 |
-
class NeRFRenderer(BaseRenderer):
|
352 |
-
def __init__(self, train_batch_num, test_batch_num, bound=0.5, use_mask=False, lambda_rgb_loss=1.0, lambda_mask_loss=0.0):
|
353 |
-
super().__init__(train_batch_num, test_batch_num)
|
354 |
-
self.train_batch_num = train_batch_num
|
355 |
-
self.test_batch_num = test_batch_num
|
356 |
-
self.use_mask = use_mask
|
357 |
-
self.field = NGPNetwork(bound=bound)
|
358 |
-
|
359 |
-
self.update_interval = 16
|
360 |
-
self.fp16 = True
|
361 |
-
self.lambda_rgb_loss = lambda_rgb_loss
|
362 |
-
self.lambda_mask_loss = lambda_mask_loss
|
363 |
-
|
364 |
-
def render_impl(self, ray_batch, is_train, step):
|
365 |
-
rays_o, rays_d = ray_batch['rays_o'], ray_batch['rays_d']
|
366 |
-
with torch.cuda.amp.autocast(enabled=self.fp16):
|
367 |
-
if step % self.update_interval==0:
|
368 |
-
self.field.update_extra_state()
|
369 |
-
|
370 |
-
outputs = self.field.render(rays_o, rays_d,)
|
371 |
-
|
372 |
-
renderings={
|
373 |
-
'rgb': outputs['image'],
|
374 |
-
'depth': outputs['depth'],
|
375 |
-
'mask': outputs['weights_sum'].unsqueeze(-1),
|
376 |
-
}
|
377 |
-
return renderings
|
378 |
-
|
379 |
-
def render_with_loss(self, ray_batch, is_train, step):
|
380 |
-
render_outputs = self.render(ray_batch, is_train, step)
|
381 |
-
|
382 |
-
rgb_gt = ray_batch['rgb']
|
383 |
-
rgb_pr = render_outputs['rgb']
|
384 |
-
epsilon = 0.001
|
385 |
-
rgb_loss = torch.sqrt(torch.sum((rgb_gt - rgb_pr) ** 2, dim=-1) + epsilon)
|
386 |
-
rgb_loss = torch.mean(rgb_loss)
|
387 |
-
loss = rgb_loss * self.lambda_rgb_loss
|
388 |
-
loss_batch = {'rendering': rgb_loss}
|
389 |
-
|
390 |
-
if self.use_mask:
|
391 |
-
mask_loss = F.mse_loss(render_outputs['mask'], ray_batch['mask'], reduction='none')
|
392 |
-
mask_loss = torch.mean(mask_loss)
|
393 |
-
loss = loss + mask_loss * self.lambda_mask_loss
|
394 |
-
loss_batch['mask'] = mask_loss
|
395 |
-
return loss, loss_batch
|
396 |
-
|
397 |
-
|
398 |
-
class RendererTrainer(pl.LightningModule):
|
399 |
-
def __init__(self, image_path, total_steps, warm_up_steps, log_dir, train_batch_fg_num=0,
|
400 |
-
use_cube_feats=False, cube_ckpt=None, cube_cfg=None, cube_bound=0.5,
|
401 |
-
train_batch_num=4096, test_batch_num=8192, use_warm_up=True, use_mask=True,
|
402 |
-
lambda_rgb_loss=1.0, lambda_mask_loss=0.0, renderer='neus',
|
403 |
-
# used in neus
|
404 |
-
lambda_eikonal_loss=0.1,
|
405 |
-
coarse_sn=64, fine_sn=64):
|
406 |
-
super().__init__()
|
407 |
-
self.num_images = 16
|
408 |
-
self.image_size = 256
|
409 |
-
self.log_dir = log_dir
|
410 |
-
(Path(log_dir)/'images').mkdir(exist_ok=True, parents=True)
|
411 |
-
self.train_batch_num = train_batch_num
|
412 |
-
self.train_batch_fg_num = train_batch_fg_num
|
413 |
-
self.test_batch_num = test_batch_num
|
414 |
-
self.image_path = image_path
|
415 |
-
self.total_steps = total_steps
|
416 |
-
self.warm_up_steps = warm_up_steps
|
417 |
-
self.use_mask = use_mask
|
418 |
-
self.lambda_eikonal_loss = lambda_eikonal_loss
|
419 |
-
self.lambda_rgb_loss = lambda_rgb_loss
|
420 |
-
self.lambda_mask_loss = lambda_mask_loss
|
421 |
-
self.use_warm_up = use_warm_up
|
422 |
-
|
423 |
-
self.use_cube_feats, self.cube_cfg, self.cube_ckpt = use_cube_feats, cube_cfg, cube_ckpt
|
424 |
-
|
425 |
-
self._init_dataset()
|
426 |
-
if renderer=='neus':
|
427 |
-
self.renderer = NeuSRenderer(train_batch_num, test_batch_num,
|
428 |
-
lambda_rgb_loss=lambda_rgb_loss,
|
429 |
-
lambda_eikonal_loss=lambda_eikonal_loss,
|
430 |
-
lambda_mask_loss=lambda_mask_loss,
|
431 |
-
coarse_sn=coarse_sn, fine_sn=fine_sn)
|
432 |
-
elif renderer=='ngp':
|
433 |
-
self.renderer = NeRFRenderer(train_batch_num, test_batch_num, bound=cube_bound, use_mask=use_mask, lambda_mask_loss=lambda_mask_loss, lambda_rgb_loss=lambda_rgb_loss,)
|
434 |
-
else:
|
435 |
-
raise NotImplementedError
|
436 |
-
self.validation_index = 0
|
437 |
-
|
438 |
-
def _construct_ray_batch(self, images_info):
|
439 |
-
image_num = images_info['images'].shape[0]
|
440 |
-
_, h, w, _ = images_info['images'].shape
|
441 |
-
coords = torch.stack(torch.meshgrid(torch.arange(h), torch.arange(w)), -1)[:, :, (1, 0)] # h,w,2
|
442 |
-
coords = coords.float()[None, :, :, :].repeat(image_num, 1, 1, 1) # imn,h,w,2
|
443 |
-
coords = coords.reshape(image_num, h * w, 2)
|
444 |
-
coords = torch.cat([coords, torch.ones(image_num, h * w, 1, dtype=torch.float32)], 2) # imn,h*w,3
|
445 |
-
|
446 |
-
# imn,h*w,3 @ imn,3,3 => imn,h*w,3
|
447 |
-
rays_d = coords @ torch.inverse(images_info['Ks']).permute(0, 2, 1)
|
448 |
-
poses = images_info['poses'] # imn,3,4
|
449 |
-
R, t = poses[:, :, :3], poses[:, :, 3:]
|
450 |
-
rays_d = rays_d @ R
|
451 |
-
rays_d = F.normalize(rays_d, dim=-1)
|
452 |
-
rays_o = -R.permute(0,2,1) @ t # imn,3,3 @ imn,3,1
|
453 |
-
rays_o = rays_o.permute(0, 2, 1).repeat(1, h*w, 1) # imn,h*w,3
|
454 |
-
|
455 |
-
ray_batch = {
|
456 |
-
'rgb': images_info['images'].reshape(image_num*h*w,3),
|
457 |
-
'mask': images_info['masks'].reshape(image_num*h*w,1),
|
458 |
-
'rays_o': rays_o.reshape(image_num*h*w,3).float(),
|
459 |
-
'rays_d': rays_d.reshape(image_num*h*w,3).float(),
|
460 |
-
}
|
461 |
-
return ray_batch
|
462 |
-
|
463 |
-
@staticmethod
|
464 |
-
def load_model(cfg, ckpt):
|
465 |
-
config = OmegaConf.load(cfg)
|
466 |
-
model = instantiate_from_config(config.model)
|
467 |
-
print(f'loading model from {ckpt} ...')
|
468 |
-
ckpt = torch.load(ckpt)
|
469 |
-
model.load_state_dict(ckpt['state_dict'])
|
470 |
-
model = model.cuda().eval()
|
471 |
-
return model
|
472 |
-
|
473 |
-
def _init_dataset(self):
|
474 |
-
mask_predictor = BackgroundRemoval()
|
475 |
-
self.K, self.azs, self.els, self.dists, self.poses = read_pickle(f'meta_info/camera-{self.num_images}.pkl')
|
476 |
-
|
477 |
-
self.images_info = {'images': [] ,'masks': [], 'Ks': [], 'poses':[]}
|
478 |
-
|
479 |
-
img = imread(self.image_path)
|
480 |
-
|
481 |
-
for index in range(self.num_images):
|
482 |
-
rgb = np.copy(img[:,index*self.image_size:(index+1)*self.image_size,:])
|
483 |
-
# predict mask
|
484 |
-
if self.use_mask:
|
485 |
-
imsave(f'{self.log_dir}/input-{index}.png', rgb)
|
486 |
-
masked_image = mask_predictor(rgb)
|
487 |
-
imsave(f'{self.log_dir}/masked-{index}.png', masked_image)
|
488 |
-
mask = masked_image[:,:,3].astype(np.float32)/255
|
489 |
-
else:
|
490 |
-
h, w, _ = rgb.shape
|
491 |
-
mask = np.zeros([h,w], np.float32)
|
492 |
-
|
493 |
-
rgb = rgb.astype(np.float32)/255
|
494 |
-
K, pose = np.copy(self.K), self.poses[index]
|
495 |
-
self.images_info['images'].append(torch.from_numpy(rgb.astype(np.float32))) # h,w,3
|
496 |
-
self.images_info['masks'].append(torch.from_numpy(mask.astype(np.float32))) # h,w
|
497 |
-
self.images_info['Ks'].append(torch.from_numpy(K.astype(np.float32)))
|
498 |
-
self.images_info['poses'].append(torch.from_numpy(pose.astype(np.float32)))
|
499 |
-
|
500 |
-
for k, v in self.images_info.items(): self.images_info[k] = torch.stack(v, 0) # stack all values
|
501 |
-
|
502 |
-
self.train_batch = self._construct_ray_batch(self.images_info)
|
503 |
-
self.train_batch_pseudo_fg = {}
|
504 |
-
pseudo_fg_mask = torch.sum(self.train_batch['rgb']>0.99,1)!=3
|
505 |
-
for k, v in self.train_batch.items():
|
506 |
-
self.train_batch_pseudo_fg[k] = v[pseudo_fg_mask]
|
507 |
-
self.train_ray_fg_num = int(torch.sum(pseudo_fg_mask).cpu().numpy())
|
508 |
-
self.train_ray_num = self.num_images * self.image_size ** 2
|
509 |
-
self._shuffle_train_batch()
|
510 |
-
self._shuffle_train_fg_batch()
|
511 |
-
|
512 |
-
def _shuffle_train_batch(self):
|
513 |
-
self.train_batch_i = 0
|
514 |
-
shuffle_idxs = torch.randperm(self.train_ray_num, device='cpu') # shuffle
|
515 |
-
for k, v in self.train_batch.items():
|
516 |
-
self.train_batch[k] = v[shuffle_idxs]
|
517 |
-
|
518 |
-
def _shuffle_train_fg_batch(self):
|
519 |
-
self.train_batch_fg_i = 0
|
520 |
-
shuffle_idxs = torch.randperm(self.train_ray_fg_num, device='cpu') # shuffle
|
521 |
-
for k, v in self.train_batch_pseudo_fg.items():
|
522 |
-
self.train_batch_pseudo_fg[k] = v[shuffle_idxs]
|
523 |
-
|
524 |
-
|
525 |
-
def training_step(self, batch, batch_idx):
|
526 |
-
train_ray_batch = {k: v[self.train_batch_i:self.train_batch_i + self.train_batch_num].cuda() for k, v in self.train_batch.items()}
|
527 |
-
self.train_batch_i += self.train_batch_num
|
528 |
-
if self.train_batch_i + self.train_batch_num >= self.train_ray_num: self._shuffle_train_batch()
|
529 |
-
|
530 |
-
if self.train_batch_fg_num>0:
|
531 |
-
train_ray_batch_fg = {k: v[self.train_batch_fg_i:self.train_batch_fg_i+self.train_batch_fg_num].cuda() for k, v in self.train_batch_pseudo_fg.items()}
|
532 |
-
self.train_batch_fg_i += self.train_batch_fg_num
|
533 |
-
if self.train_batch_fg_i + self.train_batch_fg_num >= self.train_ray_fg_num: self._shuffle_train_fg_batch()
|
534 |
-
for k, v in train_ray_batch_fg.items():
|
535 |
-
train_ray_batch[k] = torch.cat([train_ray_batch[k], v], 0)
|
536 |
-
|
537 |
-
loss, loss_batch = self.renderer.render_with_loss(train_ray_batch, is_train=True, step=self.global_step)
|
538 |
-
self.log_dict(loss_batch, prog_bar=True, logger=True, on_step=True, on_epoch=False, rank_zero_only=True)
|
539 |
-
|
540 |
-
self.log('step', self.global_step, prog_bar=True, on_step=True, on_epoch=False, logger=False, rank_zero_only=True)
|
541 |
-
lr = self.optimizers().param_groups[0]['lr']
|
542 |
-
self.log('lr', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False, rank_zero_only=True)
|
543 |
-
return loss
|
544 |
-
|
545 |
-
def _slice_images_info(self, index):
|
546 |
-
return {k:v[index:index+1] for k, v in self.images_info.items()}
|
547 |
-
|
548 |
-
@torch.no_grad()
|
549 |
-
def validation_step(self, batch, batch_idx):
|
550 |
-
with torch.no_grad():
|
551 |
-
if self.global_rank==0:
|
552 |
-
# we output an rendering image
|
553 |
-
images_info = self._slice_images_info(self.validation_index)
|
554 |
-
self.validation_index += 1
|
555 |
-
self.validation_index %= self.num_images
|
556 |
-
|
557 |
-
test_ray_batch = self._construct_ray_batch(images_info)
|
558 |
-
test_ray_batch = {k: v.cuda() for k,v in test_ray_batch.items()}
|
559 |
-
test_ray_batch['near'], test_ray_batch['far'] = near_far_from_sphere(test_ray_batch['rays_o'], test_ray_batch['rays_d'])
|
560 |
-
render_outputs = self.renderer.render(test_ray_batch, False, self.global_step)
|
561 |
-
|
562 |
-
process = lambda x: (x.cpu().numpy() * 255).astype(np.uint8)
|
563 |
-
h, w = self.image_size, self.image_size
|
564 |
-
rgb = torch.clamp(render_outputs['rgb'].reshape(h, w, 3), max=1.0, min=0.0)
|
565 |
-
mask = torch.clamp(render_outputs['mask'].reshape(h, w, 1), max=1.0, min=0.0)
|
566 |
-
mask_ = torch.repeat_interleave(mask, 3, dim=-1)
|
567 |
-
output_image = concat_images_list(process(rgb), process(mask_))
|
568 |
-
if 'normal' in render_outputs:
|
569 |
-
normal = torch.clamp((render_outputs['normal'].reshape(h, w, 3) + 1) / 2, max=1.0, min=0.0)
|
570 |
-
normal = normal * mask # we only show foregound normal
|
571 |
-
output_image = concat_images_list(output_image, process(normal))
|
572 |
-
|
573 |
-
# save images
|
574 |
-
imsave(f'{self.log_dir}/images/{self.global_step}.jpg', output_image)
|
575 |
-
|
576 |
-
def configure_optimizers(self):
|
577 |
-
lr = self.learning_rate
|
578 |
-
opt = torch.optim.AdamW([{"params": self.renderer.parameters(), "lr": lr},], lr=lr)
|
579 |
-
|
580 |
-
def schedule_fn(step):
|
581 |
-
total_step = self.total_steps
|
582 |
-
warm_up_step = self.warm_up_steps
|
583 |
-
warm_up_init = 0.02
|
584 |
-
warm_up_end = 1.0
|
585 |
-
final_lr = 0.02
|
586 |
-
interval = 1000
|
587 |
-
times = total_step // interval
|
588 |
-
ratio = np.power(final_lr, 1/times)
|
589 |
-
if step<warm_up_step:
|
590 |
-
learning_rate = (step / warm_up_step) * (warm_up_end - warm_up_init) + warm_up_init
|
591 |
-
else:
|
592 |
-
learning_rate = ratio ** (step // interval) * warm_up_end
|
593 |
-
return learning_rate
|
594 |
-
|
595 |
-
if self.use_warm_up:
|
596 |
-
scheduler = [{
|
597 |
-
'scheduler': LambdaLR(opt, lr_lambda=schedule_fn),
|
598 |
-
'interval': 'step',
|
599 |
-
'frequency': 1
|
600 |
-
}]
|
601 |
-
else:
|
602 |
-
scheduler = []
|
603 |
-
return [opt], scheduler
|
604 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -19,5 +19,4 @@ trimesh
|
|
19 |
easydict
|
20 |
nerfacc
|
21 |
imageio-ffmpeg==0.4.7
|
22 |
-
git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
|
23 |
git+https://github.com/openai/CLIP.git
|
|
|
19 |
easydict
|
20 |
nerfacc
|
21 |
imageio-ffmpeg==0.4.7
|
|
|
22 |
git+https://github.com/openai/CLIP.git
|
train_renderer.py
DELETED
@@ -1,187 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
|
3 |
-
import imageio
|
4 |
-
import numpy as np
|
5 |
-
import torch
|
6 |
-
import torch.nn.functional as F
|
7 |
-
from pathlib import Path
|
8 |
-
|
9 |
-
import trimesh
|
10 |
-
from omegaconf import OmegaConf
|
11 |
-
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, Callback
|
12 |
-
from pytorch_lightning.loggers import TensorBoardLogger
|
13 |
-
from pytorch_lightning import Trainer
|
14 |
-
from skimage.io import imsave
|
15 |
-
from tqdm import tqdm
|
16 |
-
|
17 |
-
import mcubes
|
18 |
-
|
19 |
-
from ldm.base_utils import read_pickle, output_points
|
20 |
-
from renderer.renderer import NeuSRenderer, DEFAULT_SIDE_LENGTH
|
21 |
-
from ldm.util import instantiate_from_config
|
22 |
-
|
23 |
-
class ResumeCallBacks(Callback):
|
24 |
-
def __init__(self):
|
25 |
-
pass
|
26 |
-
|
27 |
-
def on_train_start(self, trainer, pl_module):
|
28 |
-
pl_module.optimizers().param_groups = pl_module.optimizers()._optimizer.param_groups
|
29 |
-
|
30 |
-
def render_images(model, output,):
|
31 |
-
# render from model
|
32 |
-
n = 180
|
33 |
-
azimuths = (np.arange(n) / n * np.pi * 2).astype(np.float32)
|
34 |
-
elevations = np.deg2rad(np.asarray([30] * n).astype(np.float32))
|
35 |
-
K, _, _, _, poses = read_pickle(f'meta_info/camera-16.pkl')
|
36 |
-
output_points
|
37 |
-
h, w = 256, 256
|
38 |
-
default_size = 256
|
39 |
-
K = np.diag([w/default_size,h/default_size,1.0]) @ K
|
40 |
-
imgs = []
|
41 |
-
for ni in tqdm(range(n)):
|
42 |
-
# R = euler2mat(azimuths[ni], elevations[ni], 0, 'szyx')
|
43 |
-
# R = np.asarray([[0,-1,0],[0,0,-1],[1,0,0]]) @ R
|
44 |
-
e, a = elevations[ni], azimuths[ni]
|
45 |
-
row1 = np.asarray([np.sin(e)*np.cos(a),np.sin(e)*np.sin(a),-np.cos(e)])
|
46 |
-
row0 = np.asarray([-np.sin(a),np.cos(a), 0])
|
47 |
-
row2 = np.cross(row0, row1)
|
48 |
-
R = np.stack([row0,row1,row2],0)
|
49 |
-
t = np.asarray([0,0,1.5])
|
50 |
-
pose = np.concatenate([R,t[:,None]],1)
|
51 |
-
pose_ = torch.from_numpy(pose.astype(np.float32)).unsqueeze(0)
|
52 |
-
K_ = torch.from_numpy(K.astype(np.float32)).unsqueeze(0) # [1,3,3]
|
53 |
-
|
54 |
-
coords = torch.stack(torch.meshgrid(torch.arange(h), torch.arange(w)), -1)[:, :, (1, 0)] # h,w,2
|
55 |
-
coords = coords.float()[None, :, :, :].repeat(1, 1, 1, 1) # imn,h,w,2
|
56 |
-
coords = coords.reshape(1, h * w, 2)
|
57 |
-
coords = torch.cat([coords, torch.ones(1, h * w, 1, dtype=torch.float32)], 2) # imn,h*w,3
|
58 |
-
|
59 |
-
# imn,h*w,3 @ imn,3,3 => imn,h*w,3
|
60 |
-
rays_d = coords @ torch.inverse(K_).permute(0, 2, 1)
|
61 |
-
R, t = pose_[:, :, :3], pose_[:, :, 3:]
|
62 |
-
rays_d = rays_d @ R
|
63 |
-
rays_d = F.normalize(rays_d, dim=-1)
|
64 |
-
rays_o = -R.permute(0, 2, 1) @ t # imn,3,3 @ imn,3,1
|
65 |
-
rays_o = rays_o.permute(0, 2, 1).repeat(1, h * w, 1) # imn,h*w,3
|
66 |
-
|
67 |
-
ray_batch = {
|
68 |
-
'rays_o': rays_o.reshape(-1,3).cuda(),
|
69 |
-
'rays_d': rays_d.reshape(-1,3).cuda(),
|
70 |
-
}
|
71 |
-
with torch.no_grad():
|
72 |
-
image = model.renderer.render(ray_batch,False,5000)['rgb'].reshape(h,w,3)
|
73 |
-
image = (image.cpu().numpy() * 255).astype(np.uint8)
|
74 |
-
imgs.append(image)
|
75 |
-
|
76 |
-
imageio.mimsave(f'{output}/rendering.mp4', imgs, fps=30)
|
77 |
-
|
78 |
-
def extract_fields(bound_min, bound_max, resolution, query_func, batch_size=64, outside_val=1.0):
|
79 |
-
N = batch_size
|
80 |
-
X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
|
81 |
-
Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
|
82 |
-
Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)
|
83 |
-
|
84 |
-
u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
|
85 |
-
with torch.no_grad():
|
86 |
-
for xi, xs in enumerate(X):
|
87 |
-
for yi, ys in enumerate(Y):
|
88 |
-
for zi, zs in enumerate(Z):
|
89 |
-
xx, yy, zz = torch.meshgrid(xs, ys, zs)
|
90 |
-
pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1).cuda()
|
91 |
-
val = query_func(pts).detach()
|
92 |
-
outside_mask = torch.norm(pts,dim=-1)>=1.0
|
93 |
-
val[outside_mask]=outside_val
|
94 |
-
val = val.reshape(len(xs), len(ys), len(zs)).cpu().numpy()
|
95 |
-
u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
|
96 |
-
return u
|
97 |
-
|
98 |
-
def extract_geometry(bound_min, bound_max, resolution, threshold, query_func, color_func, outside_val=1.0):
|
99 |
-
u = extract_fields(bound_min, bound_max, resolution, query_func, outside_val=outside_val)
|
100 |
-
vertices, triangles = mcubes.marching_cubes(u, threshold)
|
101 |
-
b_max_np = bound_max.detach().cpu().numpy()
|
102 |
-
b_min_np = bound_min.detach().cpu().numpy()
|
103 |
-
|
104 |
-
vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
105 |
-
vertex_colors = color_func(vertices)
|
106 |
-
return vertices, triangles, vertex_colors
|
107 |
-
|
108 |
-
def extract_mesh(model, output, resolution=512):
|
109 |
-
if not isinstance(model.renderer, NeuSRenderer): return
|
110 |
-
bbox_min = -torch.ones(3)*DEFAULT_SIDE_LENGTH
|
111 |
-
bbox_max = torch.ones(3)*DEFAULT_SIDE_LENGTH
|
112 |
-
with torch.no_grad():
|
113 |
-
vertices, triangles, vertex_colors = extract_geometry(bbox_min, bbox_max, resolution, 0, lambda x: model.renderer.sdf_network.sdf(x), lambda x: model.renderer.get_vertex_colors(x))
|
114 |
-
|
115 |
-
# output geometry
|
116 |
-
mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=vertex_colors)
|
117 |
-
mesh.export(str(f'{output}/mesh.ply'))
|
118 |
-
|
119 |
-
def main():
|
120 |
-
parser = argparse.ArgumentParser()
|
121 |
-
parser.add_argument('-i', '--image_path', type=str, required=True)
|
122 |
-
parser.add_argument('-n', '--name', type=str, required=True)
|
123 |
-
parser.add_argument('-b', '--base', type=str, default='configs/neus.yaml')
|
124 |
-
parser.add_argument('-l', '--log', type=str, default='output/renderer')
|
125 |
-
parser.add_argument('-s', '--seed', type=int, default=6033)
|
126 |
-
parser.add_argument('-g', '--gpus', type=str, default='0,')
|
127 |
-
parser.add_argument('-r', '--resume', action='store_true', default=False, dest='resume')
|
128 |
-
parser.add_argument('--fp16', action='store_true', default=False, dest='fp16')
|
129 |
-
opt = parser.parse_args()
|
130 |
-
# seed_everything(opt.seed)
|
131 |
-
|
132 |
-
# configs
|
133 |
-
cfg = OmegaConf.load(opt.base)
|
134 |
-
name = opt.name
|
135 |
-
log_dir, ckpt_dir = Path(opt.log) / name, Path(opt.log) / name / 'ckpt'
|
136 |
-
cfg.model.params['image_path'] = opt.image_path
|
137 |
-
cfg.model.params['log_dir'] = log_dir
|
138 |
-
|
139 |
-
# setup
|
140 |
-
log_dir.mkdir(exist_ok=True, parents=True)
|
141 |
-
ckpt_dir.mkdir(exist_ok=True, parents=True)
|
142 |
-
trainer_config = cfg.trainer
|
143 |
-
callback_config = cfg.callbacks
|
144 |
-
model_config = cfg.model
|
145 |
-
data_config = cfg.data
|
146 |
-
|
147 |
-
data_config.params.seed = opt.seed
|
148 |
-
data = instantiate_from_config(data_config)
|
149 |
-
data.prepare_data()
|
150 |
-
data.setup('fit')
|
151 |
-
|
152 |
-
model = instantiate_from_config(model_config,)
|
153 |
-
model.cpu()
|
154 |
-
model.learning_rate = model_config.base_lr
|
155 |
-
|
156 |
-
# logger
|
157 |
-
logger = TensorBoardLogger(save_dir=log_dir, name='tensorboard_logs')
|
158 |
-
callbacks=[]
|
159 |
-
callbacks.append(LearningRateMonitor(logging_interval='step'))
|
160 |
-
callbacks.append(ModelCheckpoint(dirpath=ckpt_dir, filename="{epoch:06}", verbose=True, save_last=True, every_n_train_steps=callback_config.save_interval))
|
161 |
-
|
162 |
-
# trainer
|
163 |
-
trainer_config.update({
|
164 |
-
"accelerator": "cuda", "check_val_every_n_epoch": None,
|
165 |
-
"benchmark": True, "num_sanity_val_steps": 0,
|
166 |
-
"devices": 1, "gpus": opt.gpus,
|
167 |
-
})
|
168 |
-
if opt.fp16:
|
169 |
-
trainer_config['precision']=16
|
170 |
-
|
171 |
-
if opt.resume:
|
172 |
-
callbacks.append(ResumeCallBacks())
|
173 |
-
trainer_config['resume_from_checkpoint'] = str(ckpt_dir / 'last.ckpt')
|
174 |
-
else:
|
175 |
-
if (ckpt_dir / 'last.ckpt').exists():
|
176 |
-
raise RuntimeError(f"checkpoint {ckpt_dir / 'last.ckpt'} existing ...")
|
177 |
-
trainer = Trainer.from_argparse_args(args=argparse.Namespace(), **trainer_config, logger=logger, callbacks=callbacks)
|
178 |
-
|
179 |
-
trainer.fit(model, data)
|
180 |
-
|
181 |
-
model = model.cuda().eval()
|
182 |
-
|
183 |
-
render_images(model, log_dir)
|
184 |
-
extract_mesh(model, log_dir)
|
185 |
-
|
186 |
-
if __name__=="__main__":
|
187 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train_syncdreamer.py
DELETED
@@ -1,307 +0,0 @@
|
|
1 |
-
import argparse, os, sys
|
2 |
-
import numpy as np
|
3 |
-
import time
|
4 |
-
import torch
|
5 |
-
import torch.nn as nn
|
6 |
-
import torchvision
|
7 |
-
import pytorch_lightning as pl
|
8 |
-
|
9 |
-
from omegaconf import OmegaConf
|
10 |
-
from PIL import Image
|
11 |
-
|
12 |
-
from pytorch_lightning import seed_everything
|
13 |
-
from pytorch_lightning.strategies import DDPStrategy
|
14 |
-
from pytorch_lightning.trainer import Trainer
|
15 |
-
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
|
16 |
-
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
17 |
-
|
18 |
-
from ldm.util import instantiate_from_config
|
19 |
-
|
20 |
-
|
21 |
-
@rank_zero_only
|
22 |
-
def rank_zero_print(*args):
|
23 |
-
print(*args)
|
24 |
-
|
25 |
-
def get_parser(**parser_kwargs):
|
26 |
-
def str2bool(v):
|
27 |
-
if isinstance(v, bool):
|
28 |
-
return v
|
29 |
-
if v.lower() in ("yes", "true", "t", "y", "1"):
|
30 |
-
return True
|
31 |
-
elif v.lower() in ("no", "false", "f", "n", "0"):
|
32 |
-
return False
|
33 |
-
else:
|
34 |
-
raise argparse.ArgumentTypeError("Boolean value expected.")
|
35 |
-
|
36 |
-
parser = argparse.ArgumentParser(**parser_kwargs)
|
37 |
-
parser.add_argument("-r", "--resume", dest='resume', action='store_true', default=False)
|
38 |
-
parser.add_argument("-b", "--base", type=str, default='configs/syncdreamer-training.yaml',)
|
39 |
-
parser.add_argument("-l", "--logdir", type=str, default="ckpt/logs", help="directory for logging data", )
|
40 |
-
parser.add_argument("-c", "--ckptdir", type=str, default="ckpt/models", help="directory for checkpoint data", )
|
41 |
-
parser.add_argument("-s", "--seed", type=int, default=6033, help="seed for seed_everything", )
|
42 |
-
parser.add_argument("--finetune_from", type=str, default="/cfs-cq-dcc/rondyliu/models/sd-image-conditioned-v2.ckpt", help="path to checkpoint to load model state from" )
|
43 |
-
parser.add_argument("--gpus", type=str, default='0,')
|
44 |
-
return parser
|
45 |
-
|
46 |
-
def trainer_args(opt):
|
47 |
-
parser = argparse.ArgumentParser()
|
48 |
-
parser = Trainer.add_argparse_args(parser)
|
49 |
-
args = parser.parse_args([])
|
50 |
-
return sorted(k for k in vars(args) if hasattr(opt, k))
|
51 |
-
|
52 |
-
class SetupCallback(Callback):
|
53 |
-
def __init__(self, resume, logdir, ckptdir, cfgdir, config):
|
54 |
-
super().__init__()
|
55 |
-
self.resume = resume
|
56 |
-
self.logdir = logdir
|
57 |
-
self.ckptdir = ckptdir
|
58 |
-
self.cfgdir = cfgdir
|
59 |
-
self.config = config
|
60 |
-
|
61 |
-
def on_fit_start(self, trainer, pl_module):
|
62 |
-
if trainer.global_rank == 0:
|
63 |
-
# Create logdirs and save configs
|
64 |
-
os.makedirs(self.logdir, exist_ok=True)
|
65 |
-
os.makedirs(self.ckptdir, exist_ok=True)
|
66 |
-
os.makedirs(self.cfgdir, exist_ok=True)
|
67 |
-
|
68 |
-
rank_zero_print(OmegaConf.to_yaml(self.config))
|
69 |
-
OmegaConf.save(self.config, os.path.join(self.cfgdir, "configs.yaml"))
|
70 |
-
|
71 |
-
if not self.resume and os.path.exists(os.path.join(self.logdir,'checkpoints','last.ckpt')):
|
72 |
-
raise RuntimeError(f"checkpoint {os.path.join(self.logdir,'checkpoints','last.ckpt')} existing")
|
73 |
-
|
74 |
-
class ImageLogger(Callback):
|
75 |
-
def __init__(self, batch_frequency, max_images, log_images_kwargs=None):
|
76 |
-
super().__init__()
|
77 |
-
self.batch_freq = batch_frequency
|
78 |
-
self.max_images = max_images
|
79 |
-
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
|
80 |
-
|
81 |
-
@rank_zero_only
|
82 |
-
def log_to_logger(self, pl_module, images, split):
|
83 |
-
for k in images:
|
84 |
-
grid = torchvision.utils.make_grid(images[k])
|
85 |
-
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
86 |
-
|
87 |
-
tag = f"{split}/{k}"
|
88 |
-
pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step)
|
89 |
-
|
90 |
-
@rank_zero_only
|
91 |
-
def log_to_file(self, save_dir, split, images, global_step, current_epoch):
|
92 |
-
root = os.path.join(save_dir, "images", split)
|
93 |
-
for k in images:
|
94 |
-
grid = torchvision.utils.make_grid(images[k], nrow=4)
|
95 |
-
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
96 |
-
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
97 |
-
grid = grid.numpy()
|
98 |
-
grid = (grid * 255).astype(np.uint8)
|
99 |
-
filename = "{:06}-{:06}-{}.jpg".format(global_step, current_epoch, k)
|
100 |
-
path = os.path.join(root, filename)
|
101 |
-
os.makedirs(os.path.split(path)[0], exist_ok=True)
|
102 |
-
Image.fromarray(grid).save(path)
|
103 |
-
|
104 |
-
@rank_zero_only
|
105 |
-
def log_img(self, pl_module, batch, split="train"):
|
106 |
-
if split == "val": should_log = True
|
107 |
-
else: should_log = self.check_frequency(pl_module.global_step)
|
108 |
-
|
109 |
-
if should_log:
|
110 |
-
is_train = pl_module.training
|
111 |
-
if is_train: pl_module.eval()
|
112 |
-
|
113 |
-
with torch.no_grad():
|
114 |
-
images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
|
115 |
-
|
116 |
-
for k in images:
|
117 |
-
N = min(images[k].shape[0], self.max_images)
|
118 |
-
images[k] = images[k][:N]
|
119 |
-
if isinstance(images[k], torch.Tensor):
|
120 |
-
images[k] = images[k].detach().cpu()
|
121 |
-
images[k] = torch.clamp(images[k], -1., 1.)
|
122 |
-
|
123 |
-
self.log_to_file(pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch)
|
124 |
-
# self.log_to_logger(pl_module, images, split)
|
125 |
-
|
126 |
-
if is_train: pl_module.train()
|
127 |
-
|
128 |
-
def check_frequency(self, check_idx):
|
129 |
-
if (check_idx % self.batch_freq) == 0 and check_idx > 0:
|
130 |
-
return True
|
131 |
-
else:
|
132 |
-
return False
|
133 |
-
|
134 |
-
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
135 |
-
self.log_img(pl_module, batch, split="train")
|
136 |
-
|
137 |
-
@rank_zero_only
|
138 |
-
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
|
139 |
-
# print('validation ....')
|
140 |
-
# print(dataloader_idx)
|
141 |
-
# print(batch_idx)
|
142 |
-
if batch_idx==0: self.log_img(pl_module, batch, split="val")
|
143 |
-
|
144 |
-
class CUDACallback(Callback):
|
145 |
-
# see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
|
146 |
-
def on_train_epoch_start(self, trainer, pl_module):
|
147 |
-
# Reset the memory use counter
|
148 |
-
torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index)
|
149 |
-
torch.cuda.synchronize(trainer.strategy.root_device.index)
|
150 |
-
self.start_time = time.time()
|
151 |
-
|
152 |
-
def on_train_epoch_end(self, trainer, pl_module):
|
153 |
-
torch.cuda.synchronize(trainer.strategy.root_device.index)
|
154 |
-
max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2 ** 20
|
155 |
-
epoch_time = time.time() - self.start_time
|
156 |
-
|
157 |
-
try:
|
158 |
-
max_memory = trainer.strategy.reduce(max_memory)
|
159 |
-
epoch_time = trainer.strategy.reduce(epoch_time)
|
160 |
-
|
161 |
-
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
|
162 |
-
rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
|
163 |
-
except AttributeError:
|
164 |
-
pass
|
165 |
-
|
166 |
-
def get_node_name(name, parent_name):
|
167 |
-
if len(name) <= len(parent_name):
|
168 |
-
return False, ''
|
169 |
-
p = name[:len(parent_name)]
|
170 |
-
if p != parent_name:
|
171 |
-
return False, ''
|
172 |
-
return True, name[len(parent_name):]
|
173 |
-
|
174 |
-
class ResumeCallBacks(Callback):
|
175 |
-
def on_train_start(self, trainer, pl_module):
|
176 |
-
pl_module.optimizers().param_groups = pl_module.optimizers()._optimizer.param_groups
|
177 |
-
|
178 |
-
def load_pretrain_stable_diffusion(new_model, finetune_from):
|
179 |
-
rank_zero_print(f"Attempting to load state from {finetune_from}")
|
180 |
-
old_state = torch.load(finetune_from, map_location="cpu")
|
181 |
-
if "state_dict" in old_state: old_state = old_state["state_dict"]
|
182 |
-
|
183 |
-
in_filters_load = old_state["model.diffusion_model.input_blocks.0.0.weight"]
|
184 |
-
new_state = new_model.state_dict()
|
185 |
-
if "model.diffusion_model.input_blocks.0.0.weight" in new_state:
|
186 |
-
in_filters_current = new_state["model.diffusion_model.input_blocks.0.0.weight"]
|
187 |
-
in_shape = in_filters_current.shape
|
188 |
-
## because the model adopts additional inputs as conditions.
|
189 |
-
if in_shape != in_filters_load.shape:
|
190 |
-
input_keys = ["model.diffusion_model.input_blocks.0.0.weight", "model_ema.diffusion_modelinput_blocks00weight",]
|
191 |
-
for input_key in input_keys:
|
192 |
-
if input_key not in old_state or input_key not in new_state:
|
193 |
-
continue
|
194 |
-
input_weight = new_state[input_key]
|
195 |
-
if input_weight.size() != old_state[input_key].size():
|
196 |
-
print(f"Manual init: {input_key}")
|
197 |
-
input_weight.zero_()
|
198 |
-
input_weight[:, :4, :, :].copy_(old_state[input_key])
|
199 |
-
old_state[input_key] = torch.nn.parameter.Parameter(input_weight)
|
200 |
-
|
201 |
-
new_model.load_state_dict(old_state, strict=False)
|
202 |
-
|
203 |
-
def get_optional_dict(name, config):
|
204 |
-
if name in config:
|
205 |
-
cfg = config[name]
|
206 |
-
else:
|
207 |
-
cfg = OmegaConf.create()
|
208 |
-
return cfg
|
209 |
-
|
210 |
-
if __name__ == "__main__":
|
211 |
-
# now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
212 |
-
sys.path.append(os.getcwd())
|
213 |
-
opt = get_parser().parse_args()
|
214 |
-
|
215 |
-
assert opt.base != ''
|
216 |
-
name = os.path.split(opt.base)[-1]
|
217 |
-
name = os.path.splitext(name)[0]
|
218 |
-
logdir = os.path.join(opt.logdir, name)
|
219 |
-
|
220 |
-
# logdir: checkpoints+configs
|
221 |
-
ckptdir = os.path.join(opt.ckptdir, name)
|
222 |
-
cfgdir = os.path.join(logdir, "configs")
|
223 |
-
|
224 |
-
if opt.resume:
|
225 |
-
ckpt = os.path.join(ckptdir, "last.ckpt")
|
226 |
-
opt.resume_from_checkpoint = ckpt
|
227 |
-
opt.finetune_from = "" # disable finetune checkpoint
|
228 |
-
|
229 |
-
seed_everything(opt.seed)
|
230 |
-
|
231 |
-
###################config#####################
|
232 |
-
config = OmegaConf.load(opt.base) # loade default configs
|
233 |
-
lightning_config = config.lightning
|
234 |
-
trainer_config = config.lightning.trainer
|
235 |
-
for k in trainer_args(opt): # overwrite trainer configs
|
236 |
-
trainer_config[k] = getattr(opt, k)
|
237 |
-
|
238 |
-
###################trainer#####################
|
239 |
-
# training framework
|
240 |
-
gpuinfo = trainer_config["gpus"]
|
241 |
-
rank_zero_print(f"Running on GPUs {gpuinfo}")
|
242 |
-
ngpu = len(trainer_config.gpus.strip(",").split(','))
|
243 |
-
trainer_config['devices'] = ngpu
|
244 |
-
|
245 |
-
###################model#####################
|
246 |
-
model = instantiate_from_config(config.model)
|
247 |
-
model.cpu()
|
248 |
-
# load stable diffusion parameters
|
249 |
-
if opt.finetune_from != "":
|
250 |
-
load_pretrain_stable_diffusion(model, opt.finetune_from)
|
251 |
-
|
252 |
-
###################logger#####################
|
253 |
-
# default logger configs
|
254 |
-
default_logger_cfg = {"target": "pytorch_lightning.loggers.TensorBoardLogger",
|
255 |
-
"params": {"save_dir": logdir, "name": "tensorboard_logs", }}
|
256 |
-
logger_cfg = OmegaConf.create(default_logger_cfg)
|
257 |
-
logger = instantiate_from_config(logger_cfg)
|
258 |
-
|
259 |
-
###################callbacks#####################
|
260 |
-
# default ckpt callbacks
|
261 |
-
default_modelckpt_cfg = {"target": "pytorch_lightning.callbacks.ModelCheckpoint",
|
262 |
-
"params": {"dirpath": ckptdir, "filename": "{epoch:06}", "verbose": True, "save_last": True, "every_n_train_steps": 5000}}
|
263 |
-
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, get_optional_dict("modelcheckpoint", lightning_config)) # overwrite checkpoint configs
|
264 |
-
default_modelckpt_cfg_repeat = {"target": "pytorch_lightning.callbacks.ModelCheckpoint",
|
265 |
-
"params": {"dirpath": ckptdir, "filename": "{step:08}", "verbose": True, "save_last": False, "every_n_train_steps": 5000, "save_top_k": -1}}
|
266 |
-
modelckpt_cfg_repeat = OmegaConf.merge(default_modelckpt_cfg_repeat)
|
267 |
-
|
268 |
-
# add callback which sets up log directory
|
269 |
-
default_callbacks_cfg = {
|
270 |
-
"setup_callback": {
|
271 |
-
"target": "train_syncdreamer.SetupCallback",
|
272 |
-
"params": {"resume": opt.resume, "logdir": logdir, "ckptdir": ckptdir, "cfgdir": cfgdir, "config": config}
|
273 |
-
},
|
274 |
-
"learning_rate_logger": {
|
275 |
-
"target": "train_syncdreamer.LearningRateMonitor",
|
276 |
-
"params": {"logging_interval": "step"}
|
277 |
-
},
|
278 |
-
"cuda_callback": {"target": "train_syncdreamer.CUDACallback"},
|
279 |
-
}
|
280 |
-
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, get_optional_dict("callbacks", lightning_config))
|
281 |
-
callbacks_cfg['model_ckpt'] = modelckpt_cfg # add checkpoint
|
282 |
-
callbacks_cfg['model_ckpt_repeat'] = modelckpt_cfg_repeat # add checkpoint
|
283 |
-
callbacks = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] # construct all callbacks
|
284 |
-
if opt.resume:
|
285 |
-
callbacks.append(ResumeCallBacks())
|
286 |
-
|
287 |
-
trainer = Trainer.from_argparse_args(args=argparse.Namespace(), **trainer_config,
|
288 |
-
accelerator='cuda', strategy=DDPStrategy(find_unused_parameters=False), logger=logger, callbacks=callbacks)
|
289 |
-
trainer.logdir = logdir
|
290 |
-
|
291 |
-
###################data#####################
|
292 |
-
config.data.params.seed = opt.seed
|
293 |
-
data = instantiate_from_config(config.data)
|
294 |
-
data.prepare_data()
|
295 |
-
data.setup('fit')
|
296 |
-
|
297 |
-
####################lr#####################
|
298 |
-
bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
|
299 |
-
accumulate_grad_batches = trainer_config.accumulate_grad_batches if hasattr(trainer_config, "trainer_config") else 1
|
300 |
-
rank_zero_print(f"accumulate_grad_batches = {accumulate_grad_batches}")
|
301 |
-
model.learning_rate = base_lr
|
302 |
-
rank_zero_print("++++ NOT USING LR SCALING ++++")
|
303 |
-
rank_zero_print(f"Setting learning rate to {model.learning_rate:.2e}")
|
304 |
-
model.image_dir = logdir # used in output images during training
|
305 |
-
|
306 |
-
# run
|
307 |
-
trainer.fit(model, data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|