Spaces:
Runtime error
Runtime error
Upload util/flexicubes.py with huggingface_hub
Browse files- util/flexicubes.py +579 -579
util/flexicubes.py
CHANGED
@@ -1,579 +1,579 @@
|
|
1 |
-
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
-
import torch
|
9 |
-
from util.tables import *
|
10 |
-
|
11 |
-
__all__ = [
|
12 |
-
'FlexiCubes'
|
13 |
-
]
|
14 |
-
|
15 |
-
|
16 |
-
class FlexiCubes:
|
17 |
-
"""
|
18 |
-
This class implements the FlexiCubes method for extracting meshes from scalar fields.
|
19 |
-
It maintains a series of lookup tables and indices to support the mesh extraction process.
|
20 |
-
FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances
|
21 |
-
the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting
|
22 |
-
the surface representation through gradient-based optimization.
|
23 |
-
|
24 |
-
During instantiation, the class loads DMC tables from a file and transforms them into
|
25 |
-
PyTorch tensors on the specified device.
|
26 |
-
|
27 |
-
Attributes:
|
28 |
-
device (str): Specifies the computational device (default is "cuda").
|
29 |
-
dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges
|
30 |
-
associated with each dual vertex in 256 Marching Cubes (MC) configurations.
|
31 |
-
num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of
|
32 |
-
the 256 MC configurations.
|
33 |
-
check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19
|
34 |
-
of the DMC configurations.
|
35 |
-
tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface.
|
36 |
-
quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles
|
37 |
-
along one diagonal.
|
38 |
-
quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into
|
39 |
-
two triangles along the other diagonal.
|
40 |
-
quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles
|
41 |
-
during training by connecting all edges to their midpoints.
|
42 |
-
cube_corners (torch.Tensor): Defines the positions of a standard unit cube's
|
43 |
-
eight corners in 3D space, ordered starting from the origin (0,0,0),
|
44 |
-
moving along the x-axis, then y-axis, and finally z-axis.
|
45 |
-
Used as a blueprint for generating a voxel grid.
|
46 |
-
cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used
|
47 |
-
to retrieve the case id.
|
48 |
-
cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs.
|
49 |
-
Used to retrieve edge vertices in DMC.
|
50 |
-
edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with
|
51 |
-
their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the
|
52 |
-
first edge is oriented along the x-axis.
|
53 |
-
dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges
|
54 |
-
across four adjacent cubes to the shared faces of these cubes. For instance,
|
55 |
-
dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along
|
56 |
-
the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively.
|
57 |
-
This tensor is only utilized during isosurface tetrahedralization.
|
58 |
-
adj_pairs (torch.Tensor):
|
59 |
-
A tensor containing index pairs that correspond to neighboring cubes that share the same edge.
|
60 |
-
qef_reg_scale (float):
|
61 |
-
The scaling factor applied to the regularization loss to prevent issues with singularity
|
62 |
-
when solving the QEF. This parameter is only used when a 'grad_func' is specified.
|
63 |
-
weight_scale (float):
|
64 |
-
The scale of weights in FlexiCubes. Should be between 0 and 1.
|
65 |
-
"""
|
66 |
-
|
67 |
-
def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99):
|
68 |
-
|
69 |
-
self.device = device
|
70 |
-
self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False)
|
71 |
-
self.num_vd_table = torch.tensor(num_vd_table,
|
72 |
-
dtype=torch.long, device=device, requires_grad=False)
|
73 |
-
self.check_table = torch.tensor(
|
74 |
-
check_table,
|
75 |
-
dtype=torch.long, device=device, requires_grad=False)
|
76 |
-
|
77 |
-
self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False)
|
78 |
-
self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
|
79 |
-
self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
|
80 |
-
self.quad_split_train = torch.tensor(
|
81 |
-
[0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False)
|
82 |
-
|
83 |
-
self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
|
84 |
-
1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device)
|
85 |
-
self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False))
|
86 |
-
self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
|
87 |
-
2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False)
|
88 |
-
|
89 |
-
self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1],
|
90 |
-
dtype=torch.long, device=device)
|
91 |
-
self.dir_faces_table = torch.tensor([
|
92 |
-
[[5, 4], [3, 2], [4, 5], [2, 3]],
|
93 |
-
[[5, 4], [1, 0], [4, 5], [0, 1]],
|
94 |
-
[[3, 2], [1, 0], [2, 3], [0, 1]]
|
95 |
-
], dtype=torch.long, device=device)
|
96 |
-
self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device)
|
97 |
-
self.qef_reg_scale = qef_reg_scale
|
98 |
-
self.weight_scale = weight_scale
|
99 |
-
|
100 |
-
def construct_voxel_grid(self, res):
|
101 |
-
"""
|
102 |
-
Generates a voxel grid based on the specified resolution.
|
103 |
-
|
104 |
-
Args:
|
105 |
-
res (int or list[int]): The resolution of the voxel grid. If an integer
|
106 |
-
is provided, it is used for all three dimensions. If a list or tuple
|
107 |
-
of 3 integers is provided, they define the resolution for the x,
|
108 |
-
y, and z dimensions respectively.
|
109 |
-
|
110 |
-
Returns:
|
111 |
-
(torch.Tensor, torch.Tensor): Returns the vertices and the indices of the
|
112 |
-
cube corners (index into vertices) of the constructed voxel grid.
|
113 |
-
The vertices are centered at the origin, with the length of each
|
114 |
-
dimension in the grid being one.
|
115 |
-
"""
|
116 |
-
base_cube_f = torch.arange(8).to(self.device)
|
117 |
-
if isinstance(res, int):
|
118 |
-
res = (res, res, res)
|
119 |
-
voxel_grid_template = torch.ones(res, device=self.device)
|
120 |
-
|
121 |
-
res = torch.tensor([res], dtype=torch.float, device=self.device)
|
122 |
-
coords = torch.nonzero(voxel_grid_template).float() / res # N, 3
|
123 |
-
verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(-1, 3)
|
124 |
-
cubes = (base_cube_f.unsqueeze(0) +
|
125 |
-
torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8).reshape(-1)
|
126 |
-
|
127 |
-
verts_rounded = torch.round(verts * 10**5) / (10**5)
|
128 |
-
verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True)
|
129 |
-
cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8)
|
130 |
-
|
131 |
-
return verts_unique - 0.5, cubes
|
132 |
-
|
133 |
-
def __call__(self, x_nx3, s_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None,
|
134 |
-
gamma_f=None, training=False, output_tetmesh=False, grad_func=None):
|
135 |
-
r"""
|
136 |
-
Main function for mesh extraction from scalar field using FlexiCubes. This function converts
|
137 |
-
discrete signed distance fields, encoded on voxel grids and additional per-cube parameters,
|
138 |
-
to triangle or tetrahedral meshes using a differentiable operation as described in
|
139 |
-
`Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances
|
140 |
-
mesh quality and geometric fidelity by adjusting the surface representation based on gradient
|
141 |
-
optimization. The output surface is differentiable with respect to the input vertex positions,
|
142 |
-
scalar field values, and weight parameters.
|
143 |
-
|
144 |
-
If you intend to extract a surface mesh from a fixed Signed Distance Field without the
|
145 |
-
optimization of parameters, it is suggested to provide the "grad_func" which should
|
146 |
-
return the surface gradient at any given 3D position. When grad_func is provided, the process
|
147 |
-
to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as
|
148 |
-
described in the `Manifold Dual Contouring`_ paper, and employs an smart splitting strategy.
|
149 |
-
Please note, this approach is non-differentiable.
|
150 |
-
|
151 |
-
For more details and example usage in optimization, refer to the
|
152 |
-
`Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper.
|
153 |
-
|
154 |
-
Args:
|
155 |
-
x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed.
|
156 |
-
s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values
|
157 |
-
denote that the corresponding vertex resides inside the isosurface. This affects
|
158 |
-
the directions of the extracted triangle faces and volume to be tetrahedralized.
|
159 |
-
cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid.
|
160 |
-
res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it
|
161 |
-
is used for all three dimensions. If a list or tuple of 3 integers is provided, they
|
162 |
-
specify the resolution for the x, y, and z dimensions respectively.
|
163 |
-
beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual
|
164 |
-
vertices positioning. Defaults to uniform value for all edges.
|
165 |
-
alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual
|
166 |
-
vertices positioning. Defaults to uniform value for all vertices.
|
167 |
-
gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of
|
168 |
-
quadrilaterals into triangles. Defaults to uniform value for all cubes.
|
169 |
-
training (bool, optional): If set to True, applies differentiable quad splitting for
|
170 |
-
training. Defaults to False.
|
171 |
-
output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise,
|
172 |
-
outputs a triangular mesh. Defaults to False.
|
173 |
-
grad_func (callable, optional): A function to compute the surface gradient at specified
|
174 |
-
3D positions (input: Nx3 positions). The function should return gradients as an Nx3
|
175 |
-
tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None.
|
176 |
-
|
177 |
-
Returns:
|
178 |
-
(torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing:
|
179 |
-
- Vertices for the extracted triangular/tetrahedral mesh.
|
180 |
-
- Faces for the extracted triangular/tetrahedral mesh.
|
181 |
-
- Regularizer L_dev, computed per dual vertex.
|
182 |
-
|
183 |
-
.. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization:
|
184 |
-
https://research.nvidia.com/labs/toronto-ai/flexicubes/
|
185 |
-
.. _Manifold Dual Contouring:
|
186 |
-
https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf
|
187 |
-
"""
|
188 |
-
|
189 |
-
surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8)
|
190 |
-
if surf_cubes.sum() == 0:
|
191 |
-
return torch.zeros(
|
192 |
-
(0, 3),
|
193 |
-
device=self.device), torch.zeros(
|
194 |
-
(0, 4),
|
195 |
-
dtype=torch.long, device=self.device) if output_tetmesh else torch.zeros(
|
196 |
-
(0, 3),
|
197 |
-
dtype=torch.long, device=self.device), torch.zeros(
|
198 |
-
(0),
|
199 |
-
device=self.device)
|
200 |
-
beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes)
|
201 |
-
|
202 |
-
case_ids = self._get_case_id(occ_fx8, surf_cubes, res)
|
203 |
-
|
204 |
-
surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes)
|
205 |
-
|
206 |
-
vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd(
|
207 |
-
x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func)
|
208 |
-
vertices, faces, s_edges, edge_indices = self._triangulate(
|
209 |
-
s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func)
|
210 |
-
if not output_tetmesh:
|
211 |
-
return vertices, faces, L_dev
|
212 |
-
else:
|
213 |
-
vertices, tets = self._tetrahedralize(
|
214 |
-
x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
|
215 |
-
surf_cubes, training)
|
216 |
-
return vertices, tets, L_dev
|
217 |
-
|
218 |
-
def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
|
219 |
-
"""
|
220 |
-
Regularizer L_dev as in Equation 8
|
221 |
-
"""
|
222 |
-
dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1)
|
223 |
-
mean_l2 = torch.zeros_like(vd[:, 0])
|
224 |
-
mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float()
|
225 |
-
mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs()
|
226 |
-
return mad
|
227 |
-
|
228 |
-
def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes):
|
229 |
-
"""
|
230 |
-
Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones.
|
231 |
-
"""
|
232 |
-
n_cubes = surf_cubes.shape[0]
|
233 |
-
|
234 |
-
if beta_fx12 is not None:
|
235 |
-
beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1)
|
236 |
-
else:
|
237 |
-
beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)
|
238 |
-
|
239 |
-
if alpha_fx8 is not None:
|
240 |
-
alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1)
|
241 |
-
else:
|
242 |
-
alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)
|
243 |
-
|
244 |
-
if gamma_f is not None:
|
245 |
-
gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2
|
246 |
-
else:
|
247 |
-
gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)
|
248 |
-
|
249 |
-
return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes]
|
250 |
-
|
251 |
-
@torch.no_grad()
|
252 |
-
def _get_case_id(self, occ_fx8, surf_cubes, res):
|
253 |
-
"""
|
254 |
-
Obtains the ID of topology cases based on cell corner occupancy. This function resolves the
|
255 |
-
ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the
|
256 |
-
supplementary material. It should be noted that this function assumes a regular grid.
|
257 |
-
"""
|
258 |
-
case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1)
|
259 |
-
|
260 |
-
problem_config = self.check_table.to(self.device)[case_ids]
|
261 |
-
to_check = problem_config[..., 0] == 1
|
262 |
-
problem_config = problem_config[to_check]
|
263 |
-
if not isinstance(res, (list, tuple)):
|
264 |
-
res = [res, res, res]
|
265 |
-
|
266 |
-
# The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array,
|
267 |
-
# 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).
|
268 |
-
# This allows efficient checking on adjacent cubes.
|
269 |
-
problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long)
|
270 |
-
vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3
|
271 |
-
vol_idx_problem = vol_idx[surf_cubes][to_check]
|
272 |
-
problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config
|
273 |
-
vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]
|
274 |
-
|
275 |
-
within_range = (
|
276 |
-
vol_idx_problem_adj[..., 0] >= 0) & (
|
277 |
-
vol_idx_problem_adj[..., 0] < res[0]) & (
|
278 |
-
vol_idx_problem_adj[..., 1] >= 0) & (
|
279 |
-
vol_idx_problem_adj[..., 1] < res[1]) & (
|
280 |
-
vol_idx_problem_adj[..., 2] >= 0) & (
|
281 |
-
vol_idx_problem_adj[..., 2] < res[2])
|
282 |
-
|
283 |
-
vol_idx_problem = vol_idx_problem[within_range]
|
284 |
-
vol_idx_problem_adj = vol_idx_problem_adj[within_range]
|
285 |
-
problem_config = problem_config[within_range]
|
286 |
-
problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0],
|
287 |
-
vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]]
|
288 |
-
# If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
|
289 |
-
to_invert = (problem_config_adj[..., 0] == 1)
|
290 |
-
idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert]
|
291 |
-
case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
|
292 |
-
return case_ids
|
293 |
-
|
294 |
-
@torch.no_grad()
|
295 |
-
def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes):
|
296 |
-
"""
|
297 |
-
Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge
|
298 |
-
can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge
|
299 |
-
and marks the cube edges with this index.
|
300 |
-
"""
|
301 |
-
occ_n = s_n < 0
|
302 |
-
all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2)
|
303 |
-
unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
|
304 |
-
|
305 |
-
unique_edges = unique_edges.long()
|
306 |
-
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
|
307 |
-
|
308 |
-
surf_edges_mask = mask_edges[_idx_map]
|
309 |
-
counts = counts[_idx_map]
|
310 |
-
|
311 |
-
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1
|
312 |
-
mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device)
|
313 |
-
# Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
|
314 |
-
# for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
|
315 |
-
idx_map = mapping[_idx_map]
|
316 |
-
surf_edges = unique_edges[mask_edges]
|
317 |
-
return surf_edges, idx_map, counts, surf_edges_mask
|
318 |
-
|
319 |
-
@torch.no_grad()
|
320 |
-
def _identify_surf_cubes(self, s_n, cube_fx8):
|
321 |
-
"""
|
322 |
-
Identifies grid cubes that intersect with the underlying surface by checking if the signs at
|
323 |
-
all corners are not identical.
|
324 |
-
"""
|
325 |
-
occ_n = s_n < 0
|
326 |
-
occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
|
327 |
-
_occ_sum = torch.sum(occ_fx8, -1)
|
328 |
-
surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)
|
329 |
-
return surf_cubes, occ_fx8
|
330 |
-
|
331 |
-
def _linear_interp(self, edges_weight, edges_x):
|
332 |
-
"""
|
333 |
-
Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.
|
334 |
-
"""
|
335 |
-
edge_dim = edges_weight.dim() - 2
|
336 |
-
assert edges_weight.shape[edge_dim] == 2
|
337 |
-
edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), -
|
338 |
-
torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim)
|
339 |
-
denominator = edges_weight.sum(edge_dim)
|
340 |
-
ue = (edges_x * edges_weight).sum(edge_dim) / denominator
|
341 |
-
return ue
|
342 |
-
|
343 |
-
def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None):
|
344 |
-
p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)
|
345 |
-
norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)
|
346 |
-
c_bx3 = c_bx3.reshape(-1, 3)
|
347 |
-
A = norm_bxnx3
|
348 |
-
B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)
|
349 |
-
|
350 |
-
A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1)
|
351 |
-
B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1)
|
352 |
-
A = torch.cat([A, A_reg], 1)
|
353 |
-
B = torch.cat([B, B_reg], 1)
|
354 |
-
dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)
|
355 |
-
return dual_verts
|
356 |
-
|
357 |
-
def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func):
|
358 |
-
"""
|
359 |
-
Computes the location of dual vertices as described in Section 4.2
|
360 |
-
"""
|
361 |
-
alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2)
|
362 |
-
surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3)
|
363 |
-
surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)
|
364 |
-
zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)
|
365 |
-
|
366 |
-
idx_map = idx_map.reshape(-1, 12)
|
367 |
-
num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)
|
368 |
-
edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], []
|
369 |
-
|
370 |
-
total_num_vd = 0
|
371 |
-
vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False)
|
372 |
-
if grad_func is not None:
|
373 |
-
normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1)
|
374 |
-
vd = []
|
375 |
-
for num in torch.unique(num_vd):
|
376 |
-
cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching)
|
377 |
-
curr_num_vd = cur_cubes.sum() * num
|
378 |
-
curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7)
|
379 |
-
curr_edge_group_to_vd = torch.arange(
|
380 |
-
curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd
|
381 |
-
total_num_vd += curr_num_vd
|
382 |
-
curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[
|
383 |
-
cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group)
|
384 |
-
|
385 |
-
curr_mask = (curr_edge_group != -1)
|
386 |
-
edge_group.append(torch.masked_select(curr_edge_group, curr_mask))
|
387 |
-
edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask))
|
388 |
-
edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask))
|
389 |
-
vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))
|
390 |
-
vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1))
|
391 |
-
|
392 |
-
if grad_func is not None:
|
393 |
-
with torch.no_grad():
|
394 |
-
cube_e_verts_idx = idx_map[cur_cubes]
|
395 |
-
curr_edge_group[~curr_mask] = 0
|
396 |
-
|
397 |
-
verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group)
|
398 |
-
verts_group_idx[verts_group_idx == -1] = 0
|
399 |
-
verts_group_pos = torch.index_select(
|
400 |
-
input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3)
|
401 |
-
v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 1)
|
402 |
-
curr_mask = curr_mask.reshape(-1, num.item(), 7, 1)
|
403 |
-
verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2))
|
404 |
-
|
405 |
-
normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape(
|
406 |
-
-1, num.item(), 7,
|
407 |
-
3)
|
408 |
-
curr_mask = curr_mask.squeeze(2)
|
409 |
-
vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask,
|
410 |
-
verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3))
|
411 |
-
edge_group = torch.cat(edge_group)
|
412 |
-
edge_group_to_vd = torch.cat(edge_group_to_vd)
|
413 |
-
edge_group_to_cube = torch.cat(edge_group_to_cube)
|
414 |
-
vd_num_edges = torch.cat(vd_num_edges)
|
415 |
-
vd_gamma = torch.cat(vd_gamma)
|
416 |
-
|
417 |
-
if grad_func is not None:
|
418 |
-
vd = torch.cat(vd)
|
419 |
-
L_dev = torch.zeros([1], device=self.device)
|
420 |
-
else:
|
421 |
-
vd = torch.zeros((total_num_vd, 3), device=self.device)
|
422 |
-
beta_sum = torch.zeros((total_num_vd, 1), device=self.device)
|
423 |
-
|
424 |
-
idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group)
|
425 |
-
|
426 |
-
x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3)
|
427 |
-
s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)
|
428 |
-
|
429 |
-
zero_crossing_group = torch.index_select(
|
430 |
-
input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3)
|
431 |
-
|
432 |
-
alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0,
|
433 |
-
index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1)
|
434 |
-
ue_group = self._linear_interp(s_group * alpha_group, x_group)
|
435 |
-
|
436 |
-
beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0,
|
437 |
-
index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1)
|
438 |
-
beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)
|
439 |
-
vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum
|
440 |
-
L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges)
|
441 |
-
|
442 |
-
v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd
|
443 |
-
|
444 |
-
vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube *
|
445 |
-
12 + edge_group, src=v_idx[edge_group_to_vd])
|
446 |
-
|
447 |
-
return vd, L_dev, vd_gamma, vd_idx_map
|
448 |
-
|
449 |
-
def _triangulate(self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func):
|
450 |
-
"""
|
451 |
-
Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into
|
452 |
-
triangles based on the gamma parameter, as described in Section 4.3.
|
453 |
-
"""
|
454 |
-
with torch.no_grad():
|
455 |
-
group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes.
|
456 |
-
group = idx_map.reshape(-1)[group_mask]
|
457 |
-
vd_idx = vd_idx_map[group_mask]
|
458 |
-
edge_indices, indices = torch.sort(group, stable=True)
|
459 |
-
quad_vd_idx = vd_idx[indices].reshape(-1, 4)
|
460 |
-
|
461 |
-
# Ensure all face directions point towards the positive SDF to maintain consistent winding.
|
462 |
-
s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2)
|
463 |
-
flip_mask = s_edges[:, 0] > 0
|
464 |
-
quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
|
465 |
-
quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]]))
|
466 |
-
if grad_func is not None:
|
467 |
-
# when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients.
|
468 |
-
with torch.no_grad():
|
469 |
-
vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1)
|
470 |
-
quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
|
471 |
-
gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True)
|
472 |
-
gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True)
|
473 |
-
else:
|
474 |
-
quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4)
|
475 |
-
gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor(
|
476 |
-
0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1)
|
477 |
-
gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor(
|
478 |
-
1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1)
|
479 |
-
if not training:
|
480 |
-
mask = (gamma_02 > gamma_13).squeeze(1)
|
481 |
-
faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device)
|
482 |
-
faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]
|
483 |
-
faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]
|
484 |
-
faces = faces.reshape(-1, 3)
|
485 |
-
else:
|
486 |
-
vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
|
487 |
-
vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) +
|
488 |
-
torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2
|
489 |
-
vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) +
|
490 |
-
torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2
|
491 |
-
weight_sum = (gamma_02 + gamma_13) + 1e-8
|
492 |
-
vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) /
|
493 |
-
weight_sum.unsqueeze(-1)).squeeze(1)
|
494 |
-
vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
|
495 |
-
vd = torch.cat([vd, vd_center])
|
496 |
-
faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)
|
497 |
-
faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3)
|
498 |
-
return vd, faces, s_edges, edge_indices
|
499 |
-
|
500 |
-
def _tetrahedralize(
|
501 |
-
self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
|
502 |
-
surf_cubes, training):
|
503 |
-
"""
|
504 |
-
Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5.
|
505 |
-
"""
|
506 |
-
occ_n = s_n < 0
|
507 |
-
occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
|
508 |
-
occ_sum = torch.sum(occ_fx8, -1)
|
509 |
-
|
510 |
-
inside_verts = x_nx3[occ_n]
|
511 |
-
mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1
|
512 |
-
mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0]
|
513 |
-
"""
|
514 |
-
For each grid edge connecting two grid vertices with different
|
515 |
-
signs, we first form a four-sided pyramid by connecting one
|
516 |
-
of the grid vertices with four mesh vertices that correspond
|
517 |
-
to the grid edge and then subdivide the pyramid into two tetrahedra
|
518 |
-
"""
|
519 |
-
inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[
|
520 |
-
s_edges < 0]]
|
521 |
-
if not training:
|
522 |
-
inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1)
|
523 |
-
else:
|
524 |
-
inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1)
|
525 |
-
|
526 |
-
tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1)
|
527 |
-
"""
|
528 |
-
For each grid edge connecting two grid vertices with the
|
529 |
-
same sign, the tetrahedron is formed by the two grid vertices
|
530 |
-
and two vertices in consecutive adjacent cells
|
531 |
-
"""
|
532 |
-
inside_cubes = (occ_sum == 8)
|
533 |
-
inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1)
|
534 |
-
inside_cubes_center_idx = torch.arange(
|
535 |
-
inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0]
|
536 |
-
|
537 |
-
surface_n_inside_cubes = surf_cubes | inside_cubes
|
538 |
-
edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13),
|
539 |
-
dtype=torch.long, device=x_nx3.device) * -1
|
540 |
-
surf_cubes = surf_cubes[surface_n_inside_cubes]
|
541 |
-
inside_cubes = inside_cubes[surface_n_inside_cubes]
|
542 |
-
edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12)
|
543 |
-
edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx
|
544 |
-
|
545 |
-
all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2)
|
546 |
-
unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
|
547 |
-
unique_edges = unique_edges.long()
|
548 |
-
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2
|
549 |
-
mask = mask_edges[_idx_map]
|
550 |
-
counts = counts[_idx_map]
|
551 |
-
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1
|
552 |
-
mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device)
|
553 |
-
idx_map = mapping[_idx_map]
|
554 |
-
|
555 |
-
group_mask = (counts == 4) & mask
|
556 |
-
group = idx_map.reshape(-1)[group_mask]
|
557 |
-
edge_indices, indices = torch.sort(group)
|
558 |
-
cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long,
|
559 |
-
device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask]
|
560 |
-
edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze(
|
561 |
-
0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask]
|
562 |
-
# Identify the face shared by the adjacent cells.
|
563 |
-
cube_idx_4 = cube_idx[indices].reshape(-1, 4)
|
564 |
-
edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0]
|
565 |
-
shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1)
|
566 |
-
cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1)
|
567 |
-
# Identify an edge of the face with different signs and
|
568 |
-
# select the mesh vertex corresponding to the identified edge.
|
569 |
-
case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255
|
570 |
-
case_ids_expand[surf_cubes] = case_ids
|
571 |
-
cases = case_ids_expand[cube_idx_4x2]
|
572 |
-
quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2)
|
573 |
-
mask = (quad_edge == -1).sum(-1) == 0
|
574 |
-
inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2)
|
575 |
-
tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask]
|
576 |
-
|
577 |
-
tets = torch.cat([tets_surface, tets_inside])
|
578 |
-
vertices = torch.cat([vertices, inside_verts, inside_cubes_center])
|
579 |
-
return vertices, tets
|
|
|
1 |
+
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
import torch
|
9 |
+
from util.tables import *
|
10 |
+
|
11 |
+
__all__ = [
|
12 |
+
'FlexiCubes'
|
13 |
+
]
|
14 |
+
|
15 |
+
|
16 |
+
class FlexiCubes:
|
17 |
+
"""
|
18 |
+
This class implements the FlexiCubes method for extracting meshes from scalar fields.
|
19 |
+
It maintains a series of lookup tables and indices to support the mesh extraction process.
|
20 |
+
FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances
|
21 |
+
the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting
|
22 |
+
the surface representation through gradient-based optimization.
|
23 |
+
|
24 |
+
During instantiation, the class loads DMC tables from a file and transforms them into
|
25 |
+
PyTorch tensors on the specified device.
|
26 |
+
|
27 |
+
Attributes:
|
28 |
+
device (str): Specifies the computational device (default is "cuda").
|
29 |
+
dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges
|
30 |
+
associated with each dual vertex in 256 Marching Cubes (MC) configurations.
|
31 |
+
num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of
|
32 |
+
the 256 MC configurations.
|
33 |
+
check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19
|
34 |
+
of the DMC configurations.
|
35 |
+
tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface.
|
36 |
+
quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles
|
37 |
+
along one diagonal.
|
38 |
+
quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into
|
39 |
+
two triangles along the other diagonal.
|
40 |
+
quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles
|
41 |
+
during training by connecting all edges to their midpoints.
|
42 |
+
cube_corners (torch.Tensor): Defines the positions of a standard unit cube's
|
43 |
+
eight corners in 3D space, ordered starting from the origin (0,0,0),
|
44 |
+
moving along the x-axis, then y-axis, and finally z-axis.
|
45 |
+
Used as a blueprint for generating a voxel grid.
|
46 |
+
cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used
|
47 |
+
to retrieve the case id.
|
48 |
+
cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs.
|
49 |
+
Used to retrieve edge vertices in DMC.
|
50 |
+
edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with
|
51 |
+
their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the
|
52 |
+
first edge is oriented along the x-axis.
|
53 |
+
dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges
|
54 |
+
across four adjacent cubes to the shared faces of these cubes. For instance,
|
55 |
+
dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along
|
56 |
+
the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively.
|
57 |
+
This tensor is only utilized during isosurface tetrahedralization.
|
58 |
+
adj_pairs (torch.Tensor):
|
59 |
+
A tensor containing index pairs that correspond to neighboring cubes that share the same edge.
|
60 |
+
qef_reg_scale (float):
|
61 |
+
The scaling factor applied to the regularization loss to prevent issues with singularity
|
62 |
+
when solving the QEF. This parameter is only used when a 'grad_func' is specified.
|
63 |
+
weight_scale (float):
|
64 |
+
The scale of weights in FlexiCubes. Should be between 0 and 1.
|
65 |
+
"""
|
66 |
+
|
67 |
+
def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99):
|
68 |
+
|
69 |
+
self.device = device
|
70 |
+
self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False)
|
71 |
+
self.num_vd_table = torch.tensor(num_vd_table,
|
72 |
+
dtype=torch.long, device=device, requires_grad=False)
|
73 |
+
self.check_table = torch.tensor(
|
74 |
+
check_table,
|
75 |
+
dtype=torch.long, device=device, requires_grad=False)
|
76 |
+
|
77 |
+
self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False)
|
78 |
+
self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
|
79 |
+
self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
|
80 |
+
self.quad_split_train = torch.tensor(
|
81 |
+
[0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False)
|
82 |
+
|
83 |
+
self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
|
84 |
+
1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device)
|
85 |
+
self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False))
|
86 |
+
self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
|
87 |
+
2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False)
|
88 |
+
|
89 |
+
self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1],
|
90 |
+
dtype=torch.long, device=device)
|
91 |
+
self.dir_faces_table = torch.tensor([
|
92 |
+
[[5, 4], [3, 2], [4, 5], [2, 3]],
|
93 |
+
[[5, 4], [1, 0], [4, 5], [0, 1]],
|
94 |
+
[[3, 2], [1, 0], [2, 3], [0, 1]]
|
95 |
+
], dtype=torch.long, device=device)
|
96 |
+
self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device)
|
97 |
+
self.qef_reg_scale = qef_reg_scale
|
98 |
+
self.weight_scale = weight_scale
|
99 |
+
|
100 |
+
def construct_voxel_grid(self, res):
|
101 |
+
"""
|
102 |
+
Generates a voxel grid based on the specified resolution.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
res (int or list[int]): The resolution of the voxel grid. If an integer
|
106 |
+
is provided, it is used for all three dimensions. If a list or tuple
|
107 |
+
of 3 integers is provided, they define the resolution for the x,
|
108 |
+
y, and z dimensions respectively.
|
109 |
+
|
110 |
+
Returns:
|
111 |
+
(torch.Tensor, torch.Tensor): Returns the vertices and the indices of the
|
112 |
+
cube corners (index into vertices) of the constructed voxel grid.
|
113 |
+
The vertices are centered at the origin, with the length of each
|
114 |
+
dimension in the grid being one.
|
115 |
+
"""
|
116 |
+
base_cube_f = torch.arange(8).to(self.device)
|
117 |
+
if isinstance(res, int):
|
118 |
+
res = (res, res, res)
|
119 |
+
voxel_grid_template = torch.ones(res, device=self.device)
|
120 |
+
|
121 |
+
res = torch.tensor([res], dtype=torch.float, device=self.device)
|
122 |
+
coords = torch.nonzero(voxel_grid_template).float() / res # N, 3
|
123 |
+
verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(-1, 3)
|
124 |
+
cubes = (base_cube_f.unsqueeze(0) +
|
125 |
+
torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8).reshape(-1)
|
126 |
+
|
127 |
+
verts_rounded = torch.round(verts * 10**5) / (10**5)
|
128 |
+
verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True)
|
129 |
+
cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8)
|
130 |
+
|
131 |
+
return verts_unique - 0.5, cubes
|
132 |
+
|
133 |
+
def __call__(self, x_nx3, s_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None,
|
134 |
+
gamma_f=None, training=False, output_tetmesh=False, grad_func=None):
|
135 |
+
r"""
|
136 |
+
Main function for mesh extraction from scalar field using FlexiCubes. This function converts
|
137 |
+
discrete signed distance fields, encoded on voxel grids and additional per-cube parameters,
|
138 |
+
to triangle or tetrahedral meshes using a differentiable operation as described in
|
139 |
+
`Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances
|
140 |
+
mesh quality and geometric fidelity by adjusting the surface representation based on gradient
|
141 |
+
optimization. The output surface is differentiable with respect to the input vertex positions,
|
142 |
+
scalar field values, and weight parameters.
|
143 |
+
|
144 |
+
If you intend to extract a surface mesh from a fixed Signed Distance Field without the
|
145 |
+
optimization of parameters, it is suggested to provide the "grad_func" which should
|
146 |
+
return the surface gradient at any given 3D position. When grad_func is provided, the process
|
147 |
+
to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as
|
148 |
+
described in the `Manifold Dual Contouring`_ paper, and employs an smart splitting strategy.
|
149 |
+
Please note, this approach is non-differentiable.
|
150 |
+
|
151 |
+
For more details and example usage in optimization, refer to the
|
152 |
+
`Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed.
|
156 |
+
s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values
|
157 |
+
denote that the corresponding vertex resides inside the isosurface. This affects
|
158 |
+
the directions of the extracted triangle faces and volume to be tetrahedralized.
|
159 |
+
cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid.
|
160 |
+
res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it
|
161 |
+
is used for all three dimensions. If a list or tuple of 3 integers is provided, they
|
162 |
+
specify the resolution for the x, y, and z dimensions respectively.
|
163 |
+
beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual
|
164 |
+
vertices positioning. Defaults to uniform value for all edges.
|
165 |
+
alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual
|
166 |
+
vertices positioning. Defaults to uniform value for all vertices.
|
167 |
+
gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of
|
168 |
+
quadrilaterals into triangles. Defaults to uniform value for all cubes.
|
169 |
+
training (bool, optional): If set to True, applies differentiable quad splitting for
|
170 |
+
training. Defaults to False.
|
171 |
+
output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise,
|
172 |
+
outputs a triangular mesh. Defaults to False.
|
173 |
+
grad_func (callable, optional): A function to compute the surface gradient at specified
|
174 |
+
3D positions (input: Nx3 positions). The function should return gradients as an Nx3
|
175 |
+
tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None.
|
176 |
+
|
177 |
+
Returns:
|
178 |
+
(torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing:
|
179 |
+
- Vertices for the extracted triangular/tetrahedral mesh.
|
180 |
+
- Faces for the extracted triangular/tetrahedral mesh.
|
181 |
+
- Regularizer L_dev, computed per dual vertex.
|
182 |
+
|
183 |
+
.. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization:
|
184 |
+
https://research.nvidia.com/labs/toronto-ai/flexicubes/
|
185 |
+
.. _Manifold Dual Contouring:
|
186 |
+
https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf
|
187 |
+
"""
|
188 |
+
|
189 |
+
surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8)
|
190 |
+
if surf_cubes.sum() == 0:
|
191 |
+
return torch.zeros(
|
192 |
+
(0, 3),
|
193 |
+
device=self.device), torch.zeros(
|
194 |
+
(0, 4),
|
195 |
+
dtype=torch.long, device=self.device) if output_tetmesh else torch.zeros(
|
196 |
+
(0, 3),
|
197 |
+
dtype=torch.long, device=self.device), torch.zeros(
|
198 |
+
(0),
|
199 |
+
device=self.device)
|
200 |
+
beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes)
|
201 |
+
|
202 |
+
case_ids = self._get_case_id(occ_fx8, surf_cubes, res)
|
203 |
+
|
204 |
+
surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes)
|
205 |
+
|
206 |
+
vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd(
|
207 |
+
x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func)
|
208 |
+
vertices, faces, s_edges, edge_indices = self._triangulate(
|
209 |
+
s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func)
|
210 |
+
if not output_tetmesh:
|
211 |
+
return vertices, faces, L_dev
|
212 |
+
else:
|
213 |
+
vertices, tets = self._tetrahedralize(
|
214 |
+
x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
|
215 |
+
surf_cubes, training)
|
216 |
+
return vertices, tets, L_dev
|
217 |
+
|
218 |
+
def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
|
219 |
+
"""
|
220 |
+
Regularizer L_dev as in Equation 8
|
221 |
+
"""
|
222 |
+
dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1)
|
223 |
+
mean_l2 = torch.zeros_like(vd[:, 0])
|
224 |
+
mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float()
|
225 |
+
mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs()
|
226 |
+
return mad
|
227 |
+
|
228 |
+
def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes):
|
229 |
+
"""
|
230 |
+
Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones.
|
231 |
+
"""
|
232 |
+
n_cubes = surf_cubes.shape[0]
|
233 |
+
|
234 |
+
if beta_fx12 is not None:
|
235 |
+
beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1)
|
236 |
+
else:
|
237 |
+
beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)
|
238 |
+
|
239 |
+
if alpha_fx8 is not None:
|
240 |
+
alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1)
|
241 |
+
else:
|
242 |
+
alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)
|
243 |
+
|
244 |
+
if gamma_f is not None:
|
245 |
+
gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2
|
246 |
+
else:
|
247 |
+
gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)
|
248 |
+
|
249 |
+
return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes]
|
250 |
+
|
251 |
+
@torch.no_grad()
|
252 |
+
def _get_case_id(self, occ_fx8, surf_cubes, res):
|
253 |
+
"""
|
254 |
+
Obtains the ID of topology cases based on cell corner occupancy. This function resolves the
|
255 |
+
ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the
|
256 |
+
supplementary material. It should be noted that this function assumes a regular grid.
|
257 |
+
"""
|
258 |
+
case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1)
|
259 |
+
|
260 |
+
problem_config = self.check_table.to(self.device)[case_ids]
|
261 |
+
to_check = problem_config[..., 0] == 1
|
262 |
+
problem_config = problem_config[to_check]
|
263 |
+
if not isinstance(res, (list, tuple)):
|
264 |
+
res = [res, res, res]
|
265 |
+
|
266 |
+
# The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array,
|
267 |
+
# 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).
|
268 |
+
# This allows efficient checking on adjacent cubes.
|
269 |
+
problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long)
|
270 |
+
vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3
|
271 |
+
vol_idx_problem = vol_idx[surf_cubes][to_check]
|
272 |
+
problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config
|
273 |
+
vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]
|
274 |
+
|
275 |
+
within_range = (
|
276 |
+
vol_idx_problem_adj[..., 0] >= 0) & (
|
277 |
+
vol_idx_problem_adj[..., 0] < res[0]) & (
|
278 |
+
vol_idx_problem_adj[..., 1] >= 0) & (
|
279 |
+
vol_idx_problem_adj[..., 1] < res[1]) & (
|
280 |
+
vol_idx_problem_adj[..., 2] >= 0) & (
|
281 |
+
vol_idx_problem_adj[..., 2] < res[2])
|
282 |
+
|
283 |
+
vol_idx_problem = vol_idx_problem[within_range]
|
284 |
+
vol_idx_problem_adj = vol_idx_problem_adj[within_range]
|
285 |
+
problem_config = problem_config[within_range]
|
286 |
+
problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0],
|
287 |
+
vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]]
|
288 |
+
# If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
|
289 |
+
to_invert = (problem_config_adj[..., 0] == 1)
|
290 |
+
idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert]
|
291 |
+
case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
|
292 |
+
return case_ids
|
293 |
+
|
294 |
+
@torch.no_grad()
|
295 |
+
def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes):
|
296 |
+
"""
|
297 |
+
Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge
|
298 |
+
can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge
|
299 |
+
and marks the cube edges with this index.
|
300 |
+
"""
|
301 |
+
occ_n = s_n < 0
|
302 |
+
all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2)
|
303 |
+
unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
|
304 |
+
|
305 |
+
unique_edges = unique_edges.long()
|
306 |
+
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
|
307 |
+
|
308 |
+
surf_edges_mask = mask_edges[_idx_map]
|
309 |
+
counts = counts[_idx_map]
|
310 |
+
|
311 |
+
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1
|
312 |
+
mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device)
|
313 |
+
# Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
|
314 |
+
# for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
|
315 |
+
idx_map = mapping[_idx_map]
|
316 |
+
surf_edges = unique_edges[mask_edges]
|
317 |
+
return surf_edges, idx_map, counts, surf_edges_mask
|
318 |
+
|
319 |
+
@torch.no_grad()
|
320 |
+
def _identify_surf_cubes(self, s_n, cube_fx8):
|
321 |
+
"""
|
322 |
+
Identifies grid cubes that intersect with the underlying surface by checking if the signs at
|
323 |
+
all corners are not identical.
|
324 |
+
"""
|
325 |
+
occ_n = s_n < 0
|
326 |
+
occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
|
327 |
+
_occ_sum = torch.sum(occ_fx8, -1)
|
328 |
+
surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)
|
329 |
+
return surf_cubes, occ_fx8
|
330 |
+
|
331 |
+
def _linear_interp(self, edges_weight, edges_x):
|
332 |
+
"""
|
333 |
+
Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.
|
334 |
+
"""
|
335 |
+
edge_dim = edges_weight.dim() - 2
|
336 |
+
assert edges_weight.shape[edge_dim] == 2
|
337 |
+
edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), -
|
338 |
+
torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim)
|
339 |
+
denominator = edges_weight.sum(edge_dim)
|
340 |
+
ue = (edges_x * edges_weight).sum(edge_dim) / denominator
|
341 |
+
return ue
|
342 |
+
|
343 |
+
def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None):
|
344 |
+
p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)
|
345 |
+
norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)
|
346 |
+
c_bx3 = c_bx3.reshape(-1, 3)
|
347 |
+
A = norm_bxnx3
|
348 |
+
B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)
|
349 |
+
|
350 |
+
A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1)
|
351 |
+
B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1)
|
352 |
+
A = torch.cat([A, A_reg], 1)
|
353 |
+
B = torch.cat([B, B_reg], 1)
|
354 |
+
dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)
|
355 |
+
return dual_verts
|
356 |
+
|
357 |
+
def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func):
|
358 |
+
"""
|
359 |
+
Computes the location of dual vertices as described in Section 4.2
|
360 |
+
"""
|
361 |
+
alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2)
|
362 |
+
surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3)
|
363 |
+
surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)
|
364 |
+
zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)
|
365 |
+
|
366 |
+
idx_map = idx_map.reshape(-1, 12)
|
367 |
+
num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)
|
368 |
+
edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], []
|
369 |
+
|
370 |
+
total_num_vd = 0
|
371 |
+
vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False)
|
372 |
+
if grad_func is not None:
|
373 |
+
normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1)
|
374 |
+
vd = []
|
375 |
+
for num in torch.unique(num_vd):
|
376 |
+
cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching)
|
377 |
+
curr_num_vd = cur_cubes.sum() * num
|
378 |
+
curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7)
|
379 |
+
curr_edge_group_to_vd = torch.arange(
|
380 |
+
curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd
|
381 |
+
total_num_vd += curr_num_vd
|
382 |
+
curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[
|
383 |
+
cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group)
|
384 |
+
|
385 |
+
curr_mask = (curr_edge_group != -1)
|
386 |
+
edge_group.append(torch.masked_select(curr_edge_group, curr_mask))
|
387 |
+
edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask))
|
388 |
+
edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask))
|
389 |
+
vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))
|
390 |
+
vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1))
|
391 |
+
|
392 |
+
if grad_func is not None:
|
393 |
+
with torch.no_grad():
|
394 |
+
cube_e_verts_idx = idx_map[cur_cubes]
|
395 |
+
curr_edge_group[~curr_mask] = 0
|
396 |
+
|
397 |
+
verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group)
|
398 |
+
verts_group_idx[verts_group_idx == -1] = 0
|
399 |
+
verts_group_pos = torch.index_select(
|
400 |
+
input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3)
|
401 |
+
v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 1)
|
402 |
+
curr_mask = curr_mask.reshape(-1, num.item(), 7, 1)
|
403 |
+
verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2))
|
404 |
+
|
405 |
+
normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape(
|
406 |
+
-1, num.item(), 7,
|
407 |
+
3)
|
408 |
+
curr_mask = curr_mask.squeeze(2)
|
409 |
+
vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask,
|
410 |
+
verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3))
|
411 |
+
edge_group = torch.cat(edge_group)
|
412 |
+
edge_group_to_vd = torch.cat(edge_group_to_vd)
|
413 |
+
edge_group_to_cube = torch.cat(edge_group_to_cube)
|
414 |
+
vd_num_edges = torch.cat(vd_num_edges)
|
415 |
+
vd_gamma = torch.cat(vd_gamma)
|
416 |
+
|
417 |
+
if grad_func is not None:
|
418 |
+
vd = torch.cat(vd)
|
419 |
+
L_dev = torch.zeros([1], device=self.device)
|
420 |
+
else:
|
421 |
+
vd = torch.zeros((total_num_vd, 3), device=self.device)
|
422 |
+
beta_sum = torch.zeros((total_num_vd, 1), device=self.device)
|
423 |
+
|
424 |
+
idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group)
|
425 |
+
|
426 |
+
x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3)
|
427 |
+
s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)
|
428 |
+
|
429 |
+
zero_crossing_group = torch.index_select(
|
430 |
+
input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3)
|
431 |
+
|
432 |
+
alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0,
|
433 |
+
index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1)
|
434 |
+
ue_group = self._linear_interp(s_group * alpha_group, x_group)
|
435 |
+
|
436 |
+
beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0,
|
437 |
+
index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1)
|
438 |
+
beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)
|
439 |
+
vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum
|
440 |
+
L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges)
|
441 |
+
|
442 |
+
v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd
|
443 |
+
|
444 |
+
vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube *
|
445 |
+
12 + edge_group, src=v_idx[edge_group_to_vd])
|
446 |
+
|
447 |
+
return vd, L_dev, vd_gamma, vd_idx_map
|
448 |
+
|
449 |
+
def _triangulate(self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func):
|
450 |
+
"""
|
451 |
+
Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into
|
452 |
+
triangles based on the gamma parameter, as described in Section 4.3.
|
453 |
+
"""
|
454 |
+
with torch.no_grad():
|
455 |
+
group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes.
|
456 |
+
group = idx_map.reshape(-1)[group_mask]
|
457 |
+
vd_idx = vd_idx_map[group_mask]
|
458 |
+
edge_indices, indices = torch.sort(group, stable=True)
|
459 |
+
quad_vd_idx = vd_idx[indices].reshape(-1, 4)
|
460 |
+
|
461 |
+
# Ensure all face directions point towards the positive SDF to maintain consistent winding.
|
462 |
+
s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2)
|
463 |
+
flip_mask = s_edges[:, 0] > 0
|
464 |
+
quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
|
465 |
+
quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]]))
|
466 |
+
if grad_func is not None:
|
467 |
+
# when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients.
|
468 |
+
with torch.no_grad():
|
469 |
+
vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1)
|
470 |
+
quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
|
471 |
+
gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True)
|
472 |
+
gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True)
|
473 |
+
else:
|
474 |
+
quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4)
|
475 |
+
gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor(
|
476 |
+
0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1)
|
477 |
+
gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor(
|
478 |
+
1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1)
|
479 |
+
if not training:
|
480 |
+
mask = (gamma_02 > gamma_13).squeeze(1)
|
481 |
+
faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device)
|
482 |
+
faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]
|
483 |
+
faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]
|
484 |
+
faces = faces.reshape(-1, 3)
|
485 |
+
else:
|
486 |
+
vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
|
487 |
+
vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) +
|
488 |
+
torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2
|
489 |
+
vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) +
|
490 |
+
torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2
|
491 |
+
weight_sum = (gamma_02 + gamma_13) + 1e-8
|
492 |
+
vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) /
|
493 |
+
weight_sum.unsqueeze(-1)).squeeze(1)
|
494 |
+
vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
|
495 |
+
vd = torch.cat([vd, vd_center])
|
496 |
+
faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)
|
497 |
+
faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3)
|
498 |
+
return vd, faces, s_edges, edge_indices
|
499 |
+
|
500 |
+
def _tetrahedralize(
|
501 |
+
self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
|
502 |
+
surf_cubes, training):
|
503 |
+
"""
|
504 |
+
Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5.
|
505 |
+
"""
|
506 |
+
occ_n = s_n < 0
|
507 |
+
occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
|
508 |
+
occ_sum = torch.sum(occ_fx8, -1)
|
509 |
+
|
510 |
+
inside_verts = x_nx3[occ_n]
|
511 |
+
mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1
|
512 |
+
mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0]
|
513 |
+
"""
|
514 |
+
For each grid edge connecting two grid vertices with different
|
515 |
+
signs, we first form a four-sided pyramid by connecting one
|
516 |
+
of the grid vertices with four mesh vertices that correspond
|
517 |
+
to the grid edge and then subdivide the pyramid into two tetrahedra
|
518 |
+
"""
|
519 |
+
inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[
|
520 |
+
s_edges < 0]]
|
521 |
+
if not training:
|
522 |
+
inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1)
|
523 |
+
else:
|
524 |
+
inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1)
|
525 |
+
|
526 |
+
tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1)
|
527 |
+
"""
|
528 |
+
For each grid edge connecting two grid vertices with the
|
529 |
+
same sign, the tetrahedron is formed by the two grid vertices
|
530 |
+
and two vertices in consecutive adjacent cells
|
531 |
+
"""
|
532 |
+
inside_cubes = (occ_sum == 8)
|
533 |
+
inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1)
|
534 |
+
inside_cubes_center_idx = torch.arange(
|
535 |
+
inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0]
|
536 |
+
|
537 |
+
surface_n_inside_cubes = surf_cubes | inside_cubes
|
538 |
+
edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13),
|
539 |
+
dtype=torch.long, device=x_nx3.device) * -1
|
540 |
+
surf_cubes = surf_cubes[surface_n_inside_cubes]
|
541 |
+
inside_cubes = inside_cubes[surface_n_inside_cubes]
|
542 |
+
edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12)
|
543 |
+
edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx
|
544 |
+
|
545 |
+
all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2)
|
546 |
+
unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
|
547 |
+
unique_edges = unique_edges.long()
|
548 |
+
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2
|
549 |
+
mask = mask_edges[_idx_map]
|
550 |
+
counts = counts[_idx_map]
|
551 |
+
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1
|
552 |
+
mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device)
|
553 |
+
idx_map = mapping[_idx_map]
|
554 |
+
|
555 |
+
group_mask = (counts == 4) & mask
|
556 |
+
group = idx_map.reshape(-1)[group_mask]
|
557 |
+
edge_indices, indices = torch.sort(group)
|
558 |
+
cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long,
|
559 |
+
device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask]
|
560 |
+
edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze(
|
561 |
+
0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask]
|
562 |
+
# Identify the face shared by the adjacent cells.
|
563 |
+
cube_idx_4 = cube_idx[indices].reshape(-1, 4)
|
564 |
+
edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0]
|
565 |
+
shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1)
|
566 |
+
cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1)
|
567 |
+
# Identify an edge of the face with different signs and
|
568 |
+
# select the mesh vertex corresponding to the identified edge.
|
569 |
+
case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255
|
570 |
+
case_ids_expand[surf_cubes] = case_ids
|
571 |
+
cases = case_ids_expand[cube_idx_4x2]
|
572 |
+
quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2)
|
573 |
+
mask = (quad_edge == -1).sum(-1) == 0
|
574 |
+
inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2)
|
575 |
+
tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask]
|
576 |
+
|
577 |
+
tets = torch.cat([tets_surface, tets_inside])
|
578 |
+
vertices = torch.cat([vertices, inside_verts, inside_cubes_center])
|
579 |
+
return vertices, tets
|