Spaces:
Running
on
Zero
Running
on
Zero
File size: 9,272 Bytes
f3ff4f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 |
import torch
from easydict import EasyDict as edict
from typing import Tuple, Optional
from diso import DiffDMC
from .cube2mesh import MeshExtractResult
from .utils_cube import *
from ...modules.sparse import SparseTensor
class EnhancedMarchingCubes:
    """Mesh extraction from a dense scalar grid using ``diso.DiffDMC``.

    Optionally forwards a per-grid-vertex field to DiffDMC as its deformation
    argument, and samples a per-grid-vertex color field at the extracted
    mesh vertices.
    """

    def __init__(self, device="cuda"):
        self.device = device
        self.diffdmc = DiffDMC(dtype=torch.float32)

    @staticmethod
    def _cube_side(count: int, what: str) -> int:
        """Return ``s`` such that ``s ** 3 == count``.

        Raises:
            ValueError: if ``count`` is not a perfect cube, so a flattened
                grid cannot be reshaped to ``(s, s, s, ...)``.
        """
        side = int(round(count ** (1 / 3)))
        if side ** 3 != count:
            raise ValueError(f"{what} has {count} entries, which is not a perfect cube")
        return side

    def __call__(self,
                 voxelgrid_vertices: torch.Tensor,
                 scalar_field: torch.Tensor,
                 voxelgrid_colors: Optional[torch.Tensor] = None,
                 training: bool = False
                 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Run DiffDMC on ``scalar_field`` and optionally sample vertex colors.

        Args:
            voxelgrid_vertices: per-grid-vertex field forwarded to DiffDMC as
                its deformation argument, either ``(S, S, S, 3)`` or flattened
                ``(S**3, 3)``. May be ``None``.
            scalar_field: scalar grid, ``(S, S, S)``, flattened ``(S**3,)``, or
                with extra singleton dimensions that ``squeeze`` removes.
            voxelgrid_colors: optional color logits, ``(S, S, S, C)`` or
                ``(S**3, C)``. A sigmoid is applied before sampling.
            training: when True and a deformation field is given, also compute
                the deformation regularizer.

        Returns:
            ``(vertices, faces, deviation_loss, colors)``.
            - ``vertices``: DiffDMC's normalized vertices (presumably in
              ``[0, 1]``) shifted by ``-0.5``.
            - ``faces``: returned as produced by DiffDMC.
            - ``deviation_loss``: a scalar tensor, 0 unless ``training`` and a
              deformation field were given.
            - ``colors``: ``(N, C)``, or ``None`` if no color field was given.

        Raises:
            ValueError: if the scalar field is not 3-D after reshaping, or a
                flattened input is not a perfect cube.
        """
        if scalar_field.dim() == 1:
            side = self._cube_side(scalar_field.shape[0], "scalar_field")
            scalar_field = scalar_field.reshape(side, side, side)
        elif scalar_field.dim() > 3:
            scalar_field = scalar_field.squeeze()
        if scalar_field.dim() != 3:
            raise ValueError(f"Expected 3D array, got shape {scalar_field.shape}")
        scalar_field = scalar_field.to(self.device)

        deform_field = None
        if voxelgrid_vertices is not None:
            if voxelgrid_vertices.dim() == 2:
                side = self._cube_side(voxelgrid_vertices.shape[0], "voxelgrid_vertices")
                voxelgrid_vertices = voxelgrid_vertices.reshape(side, side, side, 3)
            deform_field = voxelgrid_vertices.to(self.device)

        vertices, faces = self.diffdmc(scalar_field, deform_field, isovalue=0.0)

        colors = None
        if voxelgrid_colors is not None:
            color_field = torch.sigmoid(voxelgrid_colors)
            if color_field.dim() == 2:
                side = self._cube_side(color_field.shape[0], "voxelgrid_colors")
                color_field = color_field.reshape(side, side, side, -1)
            # Colors must be sampled with the un-shifted (normalized) vertices.
            colors = self._sample_colors(vertices, color_field.to(vertices.device))

        # Shift normalized vertices from [0, 1] to [-0.5, 0.5].
        vertices = vertices - 0.5

        deviation_loss = torch.tensor(0.0, device=self.device)
        if training and deform_field is not None:
            deviation_loss = self._compute_deviation_loss(vertices, deform_field)
        return vertices, faces, deviation_loss, colors

    def _sample_colors(self, vertices: torch.Tensor,
                       color_field: torch.Tensor) -> torch.Tensor:
        """Trilinearly sample ``color_field`` (S0, S1, S2, C) at ``vertices`` (N, 3).

        Vertex coordinates are treated as normalized grid coordinates: 0 and 1
        fall on the first and last sample along each axis.
        """
        # DiffDMC (default normalize=True) divides grid indices by (dim - 1),
        # so the inverse mapping scales by (size - 1), not by size.
        max_index = torch.tensor([size - 1 for size in color_field.shape[:3]],
                                 dtype=vertices.dtype, device=vertices.device)
        # Clamp the continuous position (not only the integer cell) so the
        # fractional weights stay in [0, 1] and never extrapolate.
        positions = torch.minimum(torch.clamp(vertices * max_index, min=0), max_index)
        cells = torch.floor(positions)
        return self._interpolate_color(cells.long(), positions - cells, color_field)

    def _interpolate_color(self, grid_coords: torch.Tensor,
                           local_coords: torch.Tensor,
                           color_field: torch.Tensor) -> torch.Tensor:
        """Interpolate colors using trilinear interpolation.

        Args:
            grid_coords: (N, 3) integer lower-corner grid coordinates, assumed
                to already lie inside the grid.
            local_coords: (N, 3) fractional positions within the grid cells.
            color_field: (res, res, res, C) color values.

        Returns:
            (N, C) interpolated colors. Upper-corner indices are clamped to
            the grid, so cells on the last plane reuse the boundary sample.
        """
        x0, y0, z0 = grid_coords.unbind(dim=-1)
        x1 = torch.clamp(x0 + 1, 0, color_field.shape[0] - 1)
        y1 = torch.clamp(y0 + 1, 0, color_field.shape[1] - 1)
        z1 = torch.clamp(z0 + 1, 0, color_field.shape[2] - 1)
        # (N, 1) weights so they broadcast over the channel dimension.
        wx = local_coords[:, 0:1]
        wy = local_coords[:, 1:2]
        wz = local_coords[:, 2:3]
        # Interpolate along x.
        c00 = color_field[x0, y0, z0] * (1 - wx) + color_field[x1, y0, z0] * wx
        c01 = color_field[x0, y0, z1] * (1 - wx) + color_field[x1, y0, z1] * wx
        c10 = color_field[x0, y1, z0] * (1 - wx) + color_field[x1, y1, z0] * wx
        c11 = color_field[x0, y1, z1] * (1 - wx) + color_field[x1, y1, z1] * wx
        # Interpolate along y.
        c0 = c00 * (1 - wy) + c10 * wy
        c1 = c01 * (1 - wy) + c11 * wy
        # Interpolate along z.
        return c0 * (1 - wz) + c1 * wz

    def _compute_deviation_loss(self, vertices: torch.Tensor,
                                deform_field: torch.Tensor) -> torch.Tensor:
        """Return the mean squared value of ``deform_field``.

        DiffDMC already applies the deformation, so the regularizer is computed
        from the deformation field alone; ``vertices`` is unused.
        """
        return torch.mean(deform_field ** 2)
class SparseFeatures2MCMesh:
    """Convert sparse per-cube features into a mesh via ``EnhancedMarchingCubes``.

    Each active cube's feature vector is split into named channel groups (see
    ``_calc_layout``):
    - 8x1 SDF values;
    - 8x3 deformation values;
    - 21 weights;
    - when ``use_color`` is set, 8x6 color values.
    """

    def __init__(self, device="cuda", res=128, use_color=True):
        super().__init__()
        self.device = device
        self.res = res
        self.mesh_extractor = EnhancedMarchingCubes(device=device)
        # Constant offset added to every SDF value before extraction.
        self.sdf_bias = -1.0 / res
        verts, cube = construct_dense_grid(self.res, self.device)
        self.reg_c = cube.to(self.device)
        self.reg_v = verts.to(self.device)
        self.use_color = use_color
        self._calc_layout()

    def _calc_layout(self):
        """Assign each channel group a contiguous ``range`` in the feature vector.

        Groups are laid out in insertion order: sdf, deform, weights, then
        color (if enabled). Sets ``self.layouts`` and ``self.feats_channels``
        (the total channel count).
        """
        layouts = {
            'sdf': {'shape': (8, 1), 'size': 8},
            'deform': {'shape': (8, 3), 'size': 8 * 3},
            'weights': {'shape': (21,), 'size': 21},
        }
        if self.use_color:
            # 6-channel color, including a normal map.
            layouts['color'] = {'shape': (8, 6), 'size': 8 * 6}
        self.layouts = edict(layouts)
        start = 0
        for layout in self.layouts.values():
            layout['range'] = (start, start + layout['size'])
            start += layout['size']
        self.feats_channels = start

    def get_layout(self, feats: torch.Tensor, name: str):
        """Return channel group ``name`` of ``feats`` reshaped to its layout shape.

        Returns ``None`` if the group is not part of the layout (for example,
        ``'color'`` when ``use_color`` is False). The result may be a view of
        ``feats``, so callers must not modify it in place.
        """
        if name not in self.layouts:
            return None
        layout = self.layouts[name]
        start, end = layout['range']
        return feats[:, start:end].reshape(-1, *layout['shape'])

    def __call__(self, cubefeats: "SparseTensor", training=False):
        """Extract a mesh from sparse cube features.

        Args:
            cubefeats: sparse tensor. Column 0 of ``coords`` is dropped
                (presumably a batch index; TODO confirm), and ``feats`` is
                split per ``self.layouts``.
            training: when True, attach ``reg_loss``, ``tsdf_v`` and ``tsdf_s``
                to the returned mesh.

        Returns:
            A ``MeshExtractResult`` carrying vertices, faces and (optionally)
            per-vertex colors.
        """
        coords = cubefeats.coords[:, 1:]
        feats = cubefeats.feats

        sdf, deform, color, weights = [self.get_layout(feats, name)
                                       for name in ['sdf', 'deform', 'color', 'weights']]
        # Out-of-place on purpose: ``sdf`` is a view of ``cubefeats.feats``, so
        # ``+=`` would silently alter the caller's features and can break autograd.
        sdf = sdf + self.sdf_bias

        v_attrs = [sdf, deform, color] if self.use_color else [sdf, deform]
        v_pos, v_attrs, reg_loss = sparse_cube2verts(coords, torch.cat(v_attrs, dim=-1),
                                                     training=training)
        v_attrs_d = get_dense_attrs(v_pos, v_attrs, res=self.res + 1, sdf_init=True)

        # Dense channel layout: [sdf | deform (3) | color (rest, if enabled)].
        sdf_d = v_attrs_d[..., 0]
        deform_d = v_attrs_d[..., 1:4]
        colors_d = v_attrs_d[..., 4:] if self.use_color else None

        x_nx3 = get_defomed_verts(self.reg_v, deform_d, self.res)
        vertices, faces, L_dev, colors = self.mesh_extractor(
            voxelgrid_vertices=x_nx3,
            scalar_field=sdf_d,
            voxelgrid_colors=colors_d,
            training=training
        )
        mesh = MeshExtractResult(vertices=vertices, faces=faces,
                                 vertex_attrs=colors, res=self.res)
        if training:
            if mesh.success:
                reg_loss = reg_loss + L_dev.mean() * 0.5
            reg_loss = reg_loss + weights[:, :20].abs().mean() * 0.2
            mesh.reg_loss = reg_loss
            mesh.tsdf_v = get_defomed_verts(v_pos, v_attrs[:, 1:4], self.res)
            mesh.tsdf_s = v_attrs[:, 0]
        return mesh