Stable-X's picture
Upload 288 files
f3ff4f1 verified
raw
history blame
9.27 kB
import torch
from easydict import EasyDict as edict
from typing import Tuple, Optional
from diso import DiffDMC
from .cube2mesh import MeshExtractResult
from .utils_cube import *
from ...modules.sparse import SparseTensor
class EnhancedMarchingCubes:
    """Extract a triangle mesh from a dense scalar grid with DiffDMC.

    Optionally passes a per-grid-point vertex field to DiffDMC as its `deform`
    argument, and samples per-vertex colors from a dense color grid.
    """

    def __init__(self, device="cuda"):
        self.device = device
        self.diffdmc = DiffDMC(dtype=torch.float32)

    def __call__(self,
                 voxelgrid_vertices: torch.Tensor,
                 scalar_field: torch.Tensor,
                 voxelgrid_colors: Optional[torch.Tensor] = None,
                 training: bool = False
                 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Enhanced Marching Cubes implementation using DiffDMC that handles deformations and colors.

        Args:
            voxelgrid_vertices: per-grid-point vectors forwarded to DiffDMC as
                `deform`; either flattened (G**3, 3) or already (G, G, G, 3).
                May be None.
            scalar_field: scalar values, either flattened (G**3,) or a 3-D grid
                (extra singleton dims are squeezed away).
            voxelgrid_colors: optional raw color logits, flattened (G**3, C) or
                (G, G, G, C); a sigmoid is applied before sampling.
            training: when True and a deform field is given, a deviation loss
                is computed.

        Returns:
            (vertices, faces, deviation_loss, colors). Vertices are shifted by
            -0.5, i.e. DiffDMC's normalized [0, 1] output (assumed default
            `normalize=True`) becomes [-0.5, 0.5]. `colors` is None when no
            color grid is given.

        Raises:
            ValueError: if the scalar field cannot be brought to 3 dimensions.
        """
        if scalar_field.dim() == 1:
            side = int(round(scalar_field.shape[0] ** (1 / 3)))
            scalar_field = scalar_field.reshape(side, side, side)
        elif scalar_field.dim() > 3:
            scalar_field = scalar_field.squeeze()
        if scalar_field.dim() != 3:
            raise ValueError(f"Expected 3D array, got shape {scalar_field.shape}")
        scalar_field = scalar_field.to(self.device)

        # Optional per-grid-point field handed to DiffDMC as `deform`.
        deform_field = None
        if voxelgrid_vertices is not None:
            if voxelgrid_vertices.dim() == 2:
                side = int(round(voxelgrid_vertices.shape[0] ** (1 / 3)))
                voxelgrid_vertices = voxelgrid_vertices.reshape(side, side, side, 3)
            deform_field = voxelgrid_vertices.to(self.device)

        vertices, faces = self.diffdmc(
            scalar_field,
            deform_field,
            isovalue=0.0
        )

        colors = None
        if voxelgrid_colors is not None:
            voxelgrid_colors = torch.sigmoid(voxelgrid_colors.to(self.device))
            if voxelgrid_colors.dim() == 2:
                # The grid side is derived from the color tensor itself, so this
                # no longer depends on a `grid_size` set in an unrelated branch.
                side = int(round(voxelgrid_colors.shape[0] ** (1 / 3)))
                voxelgrid_colors = voxelgrid_colors.reshape(side, side, side, -1)
            colors = self._sample_vertex_colors(vertices, voxelgrid_colors)

        # [0, 1] -> [-0.5, 0.5]; equivalent to the former (v * 2 - 1) / 2.
        vertices = vertices - 0.5

        deviation_loss = torch.tensor(0.0, device=self.device)
        if training and deform_field is not None:
            deviation_loss = self._compute_deviation_loss(vertices, deform_field)
        # Faces are returned in DiffDMC's native winding order (no flip).
        return vertices, faces, deviation_loss, colors

    def _sample_vertex_colors(self, vertices: torch.Tensor,
                              color_grid: torch.Tensor) -> torch.Tensor:
        """Trilinearly sample `color_grid` at normalized vertex positions.

        Args:
            vertices: (N, 3) positions in [0, 1]. Position 0 maps to grid index 0
                and position 1 to index `dim - 1` on each axis. This matches a
                normalization by `dim - 1`.
            color_grid: (X, Y, Z, C) color values.

        Returns:
            (N, C) interpolated colors.
        """
        max_index = torch.tensor(color_grid.shape[:3], dtype=vertices.dtype,
                                 device=vertices.device) - 1
        # Scale by (dim - 1), not dim: the last grid plane sits at position 1.0.
        # Clamping before flooring keeps fractional weights in [0, 1], so
        # out-of-range vertices take edge colors instead of extrapolating.
        grid_positions = torch.minimum(torch.clamp(vertices * max_index, min=0), max_index)
        grid_coords = grid_positions.floor().long()
        local_coords = grid_positions - grid_coords.to(grid_positions.dtype)
        return self._interpolate_color(grid_coords, local_coords, color_grid)

    def _interpolate_color(self, grid_coords: torch.Tensor,
                           local_coords: torch.Tensor,
                           color_field: torch.Tensor) -> torch.Tensor:
        """
        Interpolate colors using trilinear interpolation.

        Args:
            grid_coords: (N, 3) integer grid coordinates (lower cell corner),
                expected to be within the grid bounds.
            local_coords: (N, 3) fractional positions within grid cells.
            color_field: (res, res, res, C) color values.

        Returns:
            (N, C) interpolated colors.
        """
        x0, y0, z0 = grid_coords[:, 0], grid_coords[:, 1], grid_coords[:, 2]
        # Upper neighbors are clamped so samples on the last plane reuse it.
        x1 = torch.clamp(x0 + 1, 0, color_field.shape[0] - 1)
        y1 = torch.clamp(y0 + 1, 0, color_field.shape[1] - 1)
        z1 = torch.clamp(z0 + 1, 0, color_field.shape[2] - 1)
        # (N, 1) weights broadcast over the channel dimension.
        fx = local_coords[:, 0:1]
        fy = local_coords[:, 1:2]
        fz = local_coords[:, 2:3]

        # Interpolate along x for each of the four (y, z) corner pairs.
        c00 = color_field[x0, y0, z0] * (1 - fx) + color_field[x1, y0, z0] * fx
        c01 = color_field[x0, y0, z1] * (1 - fx) + color_field[x1, y0, z1] * fx
        c10 = color_field[x0, y1, z0] * (1 - fx) + color_field[x1, y1, z0] * fx
        c11 = color_field[x0, y1, z1] * (1 - fx) + color_field[x1, y1, z1] * fx
        # Then along y, then along z.
        c0 = c00 * (1 - fy) + c10 * fy
        c1 = c01 * (1 - fy) + c11 * fy
        return c0 * (1 - fz) + c1 * fz

    def _compute_deviation_loss(self, vertices: torch.Tensor,
                                deform_field: torch.Tensor) -> torch.Tensor:
        """Return the mean squared value of `deform_field`.

        `vertices` is currently unused; it is kept for interface compatibility.
        NOTE(review): this penalizes the raw values passed as `deform_field`.
        If a caller passes absolute positions rather than offsets, this is not
        a pure deformation-magnitude penalty. Confirm the intent.
        """
        return torch.mean(deform_field ** 2)
class SparseFeatures2MCMesh:
    """Decode sparse per-cube features into a mesh via EnhancedMarchingCubes.

    Each active cube carries a flat feature vector. `_calc_layout` defines how
    it is split into SDF, deform, weights and (optionally) color slices.
    """

    def __init__(self, device="cuda", res=128, use_color=True):
        super().__init__()
        self.device = device
        self.res = res
        self.mesh_extractor = EnhancedMarchingCubes(device=device)
        # Constant offset added to every predicted SDF value.
        self.sdf_bias = -1.0 / res
        verts, cube = construct_dense_grid(self.res, self.device)
        self.reg_c = cube.to(self.device)
        self.reg_v = verts.to(self.device)
        self.use_color = use_color
        self._calc_layout()

    def _calc_layout(self):
        """Build `self.layouts` (name -> shape/size/range) and `self.feats_channels`.

        Ranges are assigned contiguously in dict insertion order:
        sdf, deform, weights, then color (when enabled).
        """
        LAYOUTS = {
            'sdf': {'shape': (8, 1), 'size': 8},
            'deform': {'shape': (8, 3), 'size': 8 * 3},
            'weights': {'shape': (21,), 'size': 21}
        }
        if self.use_color:
            # 6-channel color per cube corner, including a normal map.
            LAYOUTS['color'] = {'shape': (8, 6,), 'size': 8 * 6}
        self.layouts = edict(LAYOUTS)
        start = 0
        for v in self.layouts.values():
            v['range'] = (start, start + v['size'])
            start += v['size']
        self.feats_channels = start

    def get_layout(self, feats: torch.Tensor, name: str):
        """Return the `name` slice of `feats` reshaped to (N, *shape), or None if absent.

        The result may be a view of `feats`; do not modify it in place.
        """
        if name not in self.layouts:
            return None
        layout = self.layouts[name]
        start, end = layout['range']
        return feats[:, start:end].reshape(-1, *layout['shape'])

    def _unpack_feats(self, feats: torch.Tensor):
        """Split `feats` into (sdf, deform, color, weights), with the SDF bias applied.

        The bias is added out of place. `get_layout` can return a view, so an
        in-place add would silently modify the caller's feature tensor.
        """
        sdf, deform, color, weights = [self.get_layout(feats, name)
                                       for name in ['sdf', 'deform', 'color', 'weights']]
        sdf = sdf + self.sdf_bias
        return sdf, deform, color, weights

    def __call__(self, cubefeats: "SparseTensor", training=False):
        """Extract a mesh from sparse cube features.

        Args:
            cubefeats: sparse tensor whose `coords[:, 1:]` are used as cube
                coordinates, and whose `feats` follow `self.layouts`.
            training: when True, regularization losses are attached as
                `mesh.reg_loss`, along with `mesh.tsdf_v` / `mesh.tsdf_s`.

        Returns:
            A MeshExtractResult.
        """
        coords = cubefeats.coords[:, 1:]
        sdf, deform, color, weights = self._unpack_feats(cubefeats.feats)

        v_attrs = [sdf, deform, color] if self.use_color else [sdf, deform]
        v_pos, v_attrs, reg_loss = sparse_cube2verts(coords, torch.cat(v_attrs, dim=-1),
                                                     training=training)
        # NOTE(review): presumably a flattened dense grid of (res + 1)**3 points
        # with columns [sdf, deform(3), color...]; confirm against get_dense_attrs.
        v_attrs_d = get_dense_attrs(v_pos, v_attrs, res=self.res + 1, sdf_init=True)
        if self.use_color:
            sdf_d, deform_d, colors_d = (v_attrs_d[..., 0], v_attrs_d[..., 1:4],
                                         v_attrs_d[..., 4:])
        else:
            sdf_d, deform_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4]
            colors_d = None

        x_nx3 = get_defomed_verts(self.reg_v, deform_d, self.res)

        vertices, faces, L_dev, colors = self.mesh_extractor(
            voxelgrid_vertices=x_nx3,
            scalar_field=sdf_d,
            voxelgrid_colors=colors_d,
            training=training
        )

        mesh = MeshExtractResult(vertices=vertices, faces=faces,
                                 vertex_attrs=colors, res=self.res)
        if training:
            if mesh.success:
                reg_loss += L_dev.mean() * 0.5
            reg_loss += (weights[:, :20]).abs().mean() * 0.2
            mesh.reg_loss = reg_loss
            mesh.tsdf_v = get_defomed_verts(v_pos, v_attrs[:, 1:4], self.res)
            mesh.tsdf_s = v_attrs[:, 0]
        return mesh