# Copyright 2024 Open AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
from ...utils import BaseOutput
from .camera import create_pan_cameras


def sample_pmf(pmf: torch.Tensor, n_samples: int) -> torch.Tensor:
    r"""
    Sample from the given discrete probability distribution with replacement.

    The i-th bin is assumed to have mass pmf[i].

    Args:
        pmf: [batch_size, *shape, support_size, 1] where (pmf.sum(dim=-2) == 1).all()
        n_samples: number of samples

    Return:
        indices sampled with replacement
    """
    *shape, support_size, last_dim = pmf.shape
    assert last_dim == 1
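    # Inverse-transform sampling: build a per-ray CDF, then push uniform
    # samples through searchsorted to pick bins in proportion to their mass.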
    cdf = torch.cumsum(pmf.view(-1, support_size), dim=1)
    inds = torch.searchsorted(cdf, torch.rand(cdf.shape[0], n_samples, device=cdf.device))

    return inds.view(*shape, n_samples, 1).clamp(0, support_size - 1)


def posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) -> torch.Tensor:
    """
    Concatenate x and its positional encodings, following NeRF.

    Reference: https://arxiv.org/pdf/2210.04628.pdf
    """
    if min_deg == max_deg:
        return x

    scales = 2.0 ** torch.arange(min_deg, max_deg, dtype=x.dtype, device=x.device)
    *shape, dim = x.shape
    xb = (x.reshape(-1, 1, dim) * scales.view(1, -1, 1)).reshape(*shape, -1)
    assert xb.shape[-1] == dim * (max_deg - min_deg)
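    # sin(x + pi/2) == cos(x), so a single sin() call below yields the usual
    # paired [sin(2^k x), cos(2^k x)] NeRF features in one op.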
    emb = torch.cat([xb, xb + math.pi / 2.0], dim=-1).sin()
    return torch.cat([x, emb], dim=-1)


def encode_position(position):
    return posenc_nerf(position, min_deg=0, max_deg=15)


def encode_direction(position, direction=None):
    if direction is None:
        return torch.zeros_like(posenc_nerf(position, min_deg=0, max_deg=8))
    else:
        return posenc_nerf(direction, min_deg=0, max_deg=8)


def _sanitize_name(x: str) -> str:
    return x.replace(".", "__")


def integrate_samples(volume_range, ts, density, channels):
    r"""
    Function integrating the model output.

    Args:
        volume_range: Specifies the integral range [t0, t1]
        ts: timesteps
        density: torch.Tensor [batch_size, *shape, n_samples, 1]
        channels: torch.Tensor [batch_size, *shape, n_samples, n_channels]

    Returns:
        channels: integrated rgb output
        weights: torch.Tensor [batch_size, *shape, n_samples, 1], the (density * transmittance)[i] weight for each
            rgb output at [..., i, :]
        transmittance: transmittance of this volume
    """

    # 1. Calculate the weights
    _, _, dt = volume_range.partition(ts)
    ddensity = density * dt

    mass = torch.cumsum(ddensity, dim=-2)
    transmittance = torch.exp(-mass[..., -1, :])

    alphas = 1.0 - torch.exp(-ddensity)
    Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2))
    # This is the probability of light hitting and reflecting off of
    # something at depth [..., i, :].
    weights = alphas * Ts
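    # Discrete volume rendering equation: C = sum_i alpha_i * T_i * c_i, with
    # alpha_i = 1 - exp(-sigma_i * dt_i) and T_i = exp(-sum_{j<i} sigma_j * dt_j).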
    # 2. Integrate channels
    channels = torch.sum(channels * weights, dim=-2)

    return channels, weights, transmittance


def volume_query_points(volume, grid_size):
    indices = torch.arange(grid_size**3, device=volume.bbox_min.device)
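    # Decompose the flat index into (x, y, z) grid coordinates with z varying
    # fastest, matching the reshape order used when decoding the SDF grid.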
    zs = indices % grid_size
    ys = torch.div(indices, grid_size, rounding_mode="trunc") % grid_size
    xs = torch.div(indices, grid_size**2, rounding_mode="trunc") % grid_size
    combined = torch.stack([xs, ys, zs], dim=1)
    return (combined.float() / (grid_size - 1)) * (volume.bbox_max - volume.bbox_min) + volume.bbox_min


def _convert_srgb_to_linear(u: torch.Tensor):
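    # Standard sRGB -> linear transfer function (IEC 61966-2-1).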
    return torch.where(u <= 0.04045, u / 12.92, ((u + 0.055) / 1.055) ** 2.4)


def _create_flat_edge_indices(
    flat_cube_indices: torch.Tensor,
    grid_size: Tuple[int, int, int],
):
    num_xs = (grid_size[0] - 1) * grid_size[1] * grid_size[2]
    y_offset = num_xs
    num_ys = grid_size[0] * (grid_size[1] - 1) * grid_size[2]
    z_offset = num_xs + num_ys
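    # Edges are numbered globally: all x-spanning edges first, then all
    # y-spanning edges (offset by num_xs), then all z-spanning edges. Each
    # cube below lists its 12 edges in this global numbering.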
    return torch.stack(
        [
            # Edges spanning x-axis.
            flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
            + flat_cube_indices[:, 1] * grid_size[2]
            + flat_cube_indices[:, 2],
            flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
            + (flat_cube_indices[:, 1] + 1) * grid_size[2]
            + flat_cube_indices[:, 2],
            flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
            + flat_cube_indices[:, 1] * grid_size[2]
            + flat_cube_indices[:, 2]
            + 1,
            flat_cube_indices[:, 0] * grid_size[1] * grid_size[2]
            + (flat_cube_indices[:, 1] + 1) * grid_size[2]
            + flat_cube_indices[:, 2]
            + 1,
            # Edges spanning y-axis.
            (
                y_offset
                + flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2]
                + flat_cube_indices[:, 1] * grid_size[2]
                + flat_cube_indices[:, 2]
            ),
            (
                y_offset
                + (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2]
                + flat_cube_indices[:, 1] * grid_size[2]
                + flat_cube_indices[:, 2]
            ),
            (
                y_offset
                + flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2]
                + flat_cube_indices[:, 1] * grid_size[2]
                + flat_cube_indices[:, 2]
                + 1
            ),
            (
                y_offset
                + (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2]
                + flat_cube_indices[:, 1] * grid_size[2]
                + flat_cube_indices[:, 2]
                + 1
            ),
            # Edges spanning z-axis.
            (
                z_offset
                + flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1)
                + flat_cube_indices[:, 1] * (grid_size[2] - 1)
                + flat_cube_indices[:, 2]
            ),
            (
                z_offset
                + (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1)
                + flat_cube_indices[:, 1] * (grid_size[2] - 1)
                + flat_cube_indices[:, 2]
            ),
            (
                z_offset
                + flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1)
                + (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1)
                + flat_cube_indices[:, 2]
            ),
            (
                z_offset
                + (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1)
                + (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1)
                + flat_cube_indices[:, 2]
            ),
        ],
        dim=-1,
    )


class VoidNeRFModel(nn.Module):
    """
    Implements the default empty space model where all queries are rendered as background.
    """

    def __init__(self, background, channel_scale=255.0):
        super().__init__()
        background = nn.Parameter(torch.from_numpy(np.array(background)).to(dtype=torch.float32) / channel_scale)

        self.register_buffer("background", background)

    def forward(self, position):
        background = self.background[None].to(position.device)

        shape = position.shape[:-1]
        ones = [1] * (len(shape) - 1)
        n_channels = background.shape[-1]
        background = torch.broadcast_to(background.view(background.shape[0], *ones, n_channels), [*shape, n_channels])

        return background


@dataclass
class VolumeRange:
    t0: torch.Tensor
    t1: torch.Tensor
    intersected: torch.Tensor

    def __post_init__(self):
        assert self.t0.shape == self.t1.shape == self.intersected.shape

    def partition(self, ts):
        """
        Partitions t0 and t1 into n_samples intervals.

        Args:
            ts: [batch_size, *shape, n_samples, 1]

        Return:
            lower: [batch_size, *shape, n_samples, 1]
            upper: [batch_size, *shape, n_samples, 1]
            delta: [batch_size, *shape, n_samples, 1]

        where ts \\in [lower, upper] and delta = upper - lower.
        """
        mids = (ts[..., 1:, :] + ts[..., :-1, :]) * 0.5
        lower = torch.cat([self.t0[..., None, :], mids], dim=-2)
        upper = torch.cat([mids, self.t1[..., None, :]], dim=-2)
        delta = upper - lower
        assert lower.shape == upper.shape == delta.shape == ts.shape
        return lower, upper, delta


class BoundingBoxVolume(nn.Module):
    """
    Axis-aligned bounding box defined by the two opposite corners.
    """

    def __init__(
        self,
        *,
        bbox_min,
        bbox_max,
        min_dist: float = 0.0,
        min_t_range: float = 1e-3,
    ):
        """
        Args:
            bbox_min: the left/bottommost corner of the bounding box
            bbox_max: the other corner of the bounding box
            min_dist: all rays should start at least this distance away from the origin.
        """
        super().__init__()
        self.min_dist = min_dist
        self.min_t_range = min_t_range
        self.bbox_min = torch.tensor(bbox_min)
        self.bbox_max = torch.tensor(bbox_max)
        self.bbox = torch.stack([self.bbox_min, self.bbox_max])
        assert self.bbox.shape == (2, 3)
        assert min_dist >= 0.0
        assert min_t_range > 0.0

    def intersect(
        self,
        origin: torch.Tensor,
        direction: torch.Tensor,
        t0_lower: Optional[torch.Tensor] = None,
        epsilon=1e-6,
    ):
        """
        Args:
            origin: [batch_size, *shape, 3]
            direction: [batch_size, *shape, 3]
            t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume.
            params: Optional meta parameters in case Volume is parametric
            epsilon: to stabilize calculations

        Return:
            A tuple of (t0, t1, intersected) where each has a shape [batch_size, *shape, 1]. If a ray intersects with
            the volume, `o + td` is in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed to
            be on the boundary of the volume.
        """

        batch_size, *shape, _ = origin.shape
        ones = [1] * len(shape)
        bbox = self.bbox.view(1, *ones, 2, 3).to(origin.device)

        def _safe_divide(a, b, epsilon=1e-6):
            return a / torch.where(b < 0, b - epsilon, b + epsilon)

        ts = _safe_divide(bbox - origin[..., None, :], direction[..., None, :], epsilon=epsilon)
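        # Slab method: per-axis entry/exit times; the ray is inside the box
        # between the latest entry (t0) and the earliest exit (t1).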
        # Cases to think about:
        #
        #   1. t1 <= t0: the ray does not pass through the AABB.
        #   2. t0 < t1 <= 0: the ray intersects but the BB is behind the origin.
        #   3. t0 <= 0 <= t1: the ray starts from inside the BB.
        #   4. 0 <= t0 < t1: the ray is not inside and intersects with the BB twice.
        #
        # 1 and 4 are clearly handled from t0 < t1 below.
        # Making t0 at least min_dist (>= 0) takes care of 2 and 3.
        t0 = ts.min(dim=-2).values.max(dim=-1, keepdim=True).values.clamp(self.min_dist)
        t1 = ts.max(dim=-2).values.min(dim=-1, keepdim=True).values
        assert t0.shape == t1.shape == (batch_size, *shape, 1)
        if t0_lower is not None:
            assert t0.shape == t0_lower.shape
            t0 = torch.maximum(t0, t0_lower)

        intersected = t0 + self.min_t_range < t1
        t0 = torch.where(intersected, t0, torch.zeros_like(t0))
        t1 = torch.where(intersected, t1, torch.ones_like(t1))

        return VolumeRange(t0=t0, t1=t1, intersected=intersected)


class StratifiedRaySampler(nn.Module):
    """
    Instead of fixed intervals, a sample is drawn uniformly at random from each interval.
    """

    def __init__(self, depth_mode: str = "linear"):
        """
        :param depth_mode: linear samples ts linearly in depth. harmonic ensures
            closer points are sampled more densely.
        """
        super().__init__()
        self.depth_mode = depth_mode
        assert self.depth_mode in ("linear", "geometric", "harmonic")

    def sample(
        self,
        t0: torch.Tensor,
        t1: torch.Tensor,
        n_samples: int,
        epsilon: float = 1e-3,
    ) -> torch.Tensor:
        """
        Args:
            t0: start time has shape [batch_size, *shape, 1]
            t1: finish time has shape [batch_size, *shape, 1]
            n_samples: number of ts to sample

        Return:
            sampled ts of shape [batch_size, *shape, n_samples, 1]
        """
        ones = [1] * (len(t0.shape) - 1)
        ts = torch.linspace(0, 1, n_samples).view(*ones, n_samples).to(t0.dtype).to(t0.device)

        if self.depth_mode == "linear":
            ts = t0 * (1.0 - ts) + t1 * ts
        elif self.depth_mode == "geometric":
            ts = (t0.clamp(epsilon).log() * (1.0 - ts) + t1.clamp(epsilon).log() * ts).exp()
        elif self.depth_mode == "harmonic":
            # The original NeRF recommends this interpolation scheme for
            # spherical scenes, but there could be some weird edge cases when
            # the observer crosses from the inner to outer volume.
            ts = 1.0 / (1.0 / t0.clamp(epsilon) * (1.0 - ts) + 1.0 / t1.clamp(epsilon) * ts)

        mids = 0.5 * (ts[..., 1:] + ts[..., :-1])
        upper = torch.cat([mids, t1], dim=-1)
        lower = torch.cat([t0, mids], dim=-1)
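        # Stratified sampling: jitter one sample uniformly within each
        # interval instead of keeping the fixed midpoints.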
        # NOTE: a fixed seed is kept here for reproducible testing; remove it
        # to restore truly random stratified sampling.
        torch.manual_seed(0)
        t_rand = torch.rand_like(ts)

        ts = lower + (upper - lower) * t_rand
        return ts.unsqueeze(-1)


class ImportanceRaySampler(nn.Module):
    """
    Given the initial estimate of densities, this samples more from regions/bins expected to have objects.
    """

    def __init__(
        self,
        volume_range: VolumeRange,
        ts: torch.Tensor,
        weights: torch.Tensor,
        blur_pool: bool = False,
        alpha: float = 1e-5,
    ):
        """
        Args:
            volume_range: the range in which a ray intersects the given volume.
            ts: earlier samples from the coarse rendering step
            weights: discretized version of density * transmittance
            blur_pool: if true, use 2-tap max + 2-tap blur filter from mip-NeRF.
            alpha: small value to add to weights.
        """
        super().__init__()
        self.volume_range = volume_range
        self.ts = ts.clone().detach()
        self.weights = weights.clone().detach()
        self.blur_pool = blur_pool
        self.alpha = alpha

    def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor:
        """
        Args:
            t0: start time has shape [batch_size, *shape, 1]
            t1: finish time has shape [batch_size, *shape, 1]
            n_samples: number of ts to sample

        Return:
            sampled ts of shape [batch_size, *shape, n_samples, 1]
        """
        lower, upper, _ = self.volume_range.partition(self.ts)

        batch_size, *shape, n_coarse_samples, _ = self.ts.shape

        weights = self.weights
        if self.blur_pool:
            padded = torch.cat([weights[..., :1, :], weights, weights[..., -1:, :]], dim=-2)
            maxes = torch.maximum(padded[..., :-1, :], padded[..., 1:, :])
            weights = 0.5 * (maxes[..., :-1, :] + maxes[..., 1:, :])
        weights = weights + self.alpha
        pmf = weights / weights.sum(dim=-2, keepdim=True)
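        # Resample bins in proportion to their contribution to the coarse
        # render, then jitter uniformly within each chosen bin.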
        inds = sample_pmf(pmf, n_samples)
        assert inds.shape == (batch_size, *shape, n_samples, 1)
        assert (inds >= 0).all() and (inds < n_coarse_samples).all()

        t_rand = torch.rand(inds.shape, device=inds.device)
        lower_ = torch.gather(lower, -2, inds)
        upper_ = torch.gather(upper, -2, inds)

        ts = lower_ + (upper_ - lower_) * t_rand
        ts = torch.sort(ts, dim=-2).values
        return ts


@dataclass
class MeshDecoderOutput(BaseOutput):
    """
    A 3D triangle mesh with optional data at the vertices and faces.

    Args:
        verts (`torch.Tensor` of shape `(N, 3)`):
            array of vertex coordinates
        faces (`torch.Tensor` of shape `(N, 3)`):
            array of triangles, pointing to indices in verts.
        vertex_channels (Dict):
            vertex color values for each color channel
    """

    verts: torch.Tensor
    faces: torch.Tensor
    vertex_channels: Dict[str, torch.Tensor]


class MeshDecoder(nn.Module):
    """
    Construct meshes from signed distance functions (SDFs) using the marching cubes method.
    """

    def __init__(self):
        super().__init__()
        cases = torch.zeros(256, 5, 3, dtype=torch.long)
        masks = torch.zeros(256, 5, dtype=torch.bool)
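        # Lookup tables for the 256 marching-cubes cases (up to 5 triangles,
        # 3 edge indices each). They are zero-initialized here and are
        # presumably populated when pretrained weights are loaded.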
| self.register_buffer("cases", cases) | |
| self.register_buffer("masks", masks) | |
| def forward(self, field: torch.Tensor, min_point: torch.Tensor, size: torch.Tensor): | |
| """ | |
| For a signed distance field, produce a mesh using marching cubes. | |
| :param field: a 3D tensor of field values, where negative values correspond | |
| to the outside of the shape. The dimensions correspond to the x, y, and z directions, respectively. | |
| :param min_point: a tensor of shape [3] containing the point corresponding | |
| to (0, 0, 0) in the field. | |
| :param size: a tensor of shape [3] containing the per-axis distance from the | |
| (0, 0, 0) field corner and the (-1, -1, -1) field corner. | |
| """ | |
| assert len(field.shape) == 3, "input must be a 3D scalar field" | |
| dev = field.device | |
| cases = self.cases.to(dev) | |
| masks = self.masks.to(dev) | |
| min_point = min_point.to(dev) | |
| size = size.to(dev) | |
| grid_size = field.shape | |
| grid_size_tensor = torch.tensor(grid_size).to(size) | |
| # Create bitmasks between 0 and 255 (inclusive) indicating the state | |
| # of the eight corners of each cube. | |
| bitmasks = (field > 0).to(torch.uint8) | |
| bitmasks = bitmasks[:-1, :, :] | (bitmasks[1:, :, :] << 1) | |
| bitmasks = bitmasks[:, :-1, :] | (bitmasks[:, 1:, :] << 2) | |
| bitmasks = bitmasks[:, :, :-1] | (bitmasks[:, :, 1:] << 4) | |
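        # Each OR of a shifted copy folds one axis: after all three, every
        # cube holds an 8-bit occupancy code indexing the case tables above.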
        # Compute corner coordinates across the entire grid.
        corner_coords = torch.empty(*grid_size, 3, device=dev, dtype=field.dtype)
        corner_coords[range(grid_size[0]), :, :, 0] = torch.arange(grid_size[0], device=dev, dtype=field.dtype)[
            :, None, None
        ]
        corner_coords[:, range(grid_size[1]), :, 1] = torch.arange(grid_size[1], device=dev, dtype=field.dtype)[
            :, None
        ]
        corner_coords[:, :, range(grid_size[2]), 2] = torch.arange(grid_size[2], device=dev, dtype=field.dtype)

        # Compute all vertices across all edges in the grid, even though we will
        # throw some out later. We have (X-1)*Y*Z + X*(Y-1)*Z + X*Y*(Z-1) vertices.
        # These are all midpoints, and don't account for interpolation (which is
        # done later based on the used edge midpoints).
        edge_midpoints = torch.cat(
            [
                ((corner_coords[:-1] + corner_coords[1:]) / 2).reshape(-1, 3),
                ((corner_coords[:, :-1] + corner_coords[:, 1:]) / 2).reshape(-1, 3),
                ((corner_coords[:, :, :-1] + corner_coords[:, :, 1:]) / 2).reshape(-1, 3),
            ],
            dim=0,
        )

        # Create a flat array of [X, Y, Z] indices for each cube.
        cube_indices = torch.zeros(
            grid_size[0] - 1, grid_size[1] - 1, grid_size[2] - 1, 3, device=dev, dtype=torch.long
        )
        cube_indices[range(grid_size[0] - 1), :, :, 0] = torch.arange(grid_size[0] - 1, device=dev)[:, None, None]
        cube_indices[:, range(grid_size[1] - 1), :, 1] = torch.arange(grid_size[1] - 1, device=dev)[:, None]
        cube_indices[:, :, range(grid_size[2] - 1), 2] = torch.arange(grid_size[2] - 1, device=dev)
        flat_cube_indices = cube_indices.reshape(-1, 3)

        # Create a flat array mapping each cube to 12 global edge indices.
        edge_indices = _create_flat_edge_indices(flat_cube_indices, grid_size)

        # Apply the LUT to figure out the triangles.
        flat_bitmasks = bitmasks.reshape(-1).long()  # cast to long so indexing treats these as indices, not a mask
        local_tris = cases[flat_bitmasks]
        local_masks = masks[flat_bitmasks]
        # Compute the global edge indices for the triangles.
        global_tris = torch.gather(edge_indices, 1, local_tris.reshape(local_tris.shape[0], -1)).reshape(
            local_tris.shape
        )
        # Select the used triangles for each cube.
        selected_tris = global_tris.reshape(-1, 3)[local_masks.reshape(-1)]

        # Now we have a bunch of indices into the full list of possible vertices,
        # but we want to reduce this list to only the used vertices.
        used_vertex_indices = torch.unique(selected_tris.view(-1))
        used_edge_midpoints = edge_midpoints[used_vertex_indices]
        old_index_to_new_index = torch.zeros(len(edge_midpoints), device=dev, dtype=torch.long)
        old_index_to_new_index[used_vertex_indices] = torch.arange(
            len(used_vertex_indices), device=dev, dtype=torch.long
        )

        # Rewrite the triangles to use the new indices.
        faces = torch.gather(old_index_to_new_index, 0, selected_tris.view(-1)).reshape(selected_tris.shape)

        # Compute the actual interpolated coordinates corresponding to edge midpoints.
        v1 = torch.floor(used_edge_midpoints).to(torch.long)
        v2 = torch.ceil(used_edge_midpoints).to(torch.long)
        s1 = field[v1[:, 0], v1[:, 1], v1[:, 2]]
        s2 = field[v2[:, 0], v2[:, 1], v2[:, 2]]
        p1 = (v1.float() / (grid_size_tensor - 1)) * size + min_point
        p2 = (v2.float() / (grid_size_tensor - 1)) * size + min_point

        # The signs of s1 and s2 should be different. We want to find
        # t such that t*s2 + (1-t)*s1 = 0.
        t = (s1 / (s1 - s2))[:, None]
        verts = t * p2 + (1 - t) * p1

        return MeshDecoderOutput(verts=verts, faces=faces, vertex_channels=None)


@dataclass
class MLPNeRFModelOutput(BaseOutput):
    density: torch.Tensor
    signed_distance: torch.Tensor
    channels: torch.Tensor
    ts: torch.Tensor


class MLPNeRSTFModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        d_hidden: int = 256,
        n_output: int = 12,
        n_hidden_layers: int = 6,
        act_fn: str = "swish",
        insert_direction_at: int = 4,
    ):
        super().__init__()

        # Instantiate the MLP

        # Find out the dimension of encoded position and direction
        dummy = torch.eye(1, 3)
        d_posenc_pos = encode_position(position=dummy).shape[-1]
        d_posenc_dir = encode_direction(position=dummy).shape[-1]

        mlp_widths = [d_hidden] * n_hidden_layers
        input_widths = [d_posenc_pos] + mlp_widths
        output_widths = mlp_widths + [n_output]

        if insert_direction_at is not None:
            input_widths[insert_direction_at] += d_posenc_dir

        self.mlp = nn.ModuleList([nn.Linear(d_in, d_out) for d_in, d_out in zip(input_widths, output_widths)])

        if act_fn == "swish":
            self.activation = lambda x: F.silu(x)
        else:
            raise ValueError(f"Unsupported activation function {act_fn}")

        self.sdf_activation = torch.tanh
        self.density_activation = torch.nn.functional.relu
        self.channel_activation = torch.sigmoid
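        # The 12 raw network outputs are split into heads: sdf (1), coarse and
        # fine density (1 each), stf RGB (3), and coarse and fine NeRF RGB
        # (3 each); see map_indices_to_keys below.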

    def map_indices_to_keys(self, output):
        h_map = {
            "sdf": (0, 1),
            "density_coarse": (1, 2),
            "density_fine": (2, 3),
            "stf": (3, 6),
            "nerf_coarse": (6, 9),
            "nerf_fine": (9, 12),
        }

        mapped_output = {k: output[..., start:end] for k, (start, end) in h_map.items()}

        return mapped_output

    def forward(self, *, position, direction, ts, nerf_level="coarse", rendering_mode="nerf"):
        h = encode_position(position)

        h_preact = h
        h_directionless = None
        for i, layer in enumerate(self.mlp):
            if i == self.config.insert_direction_at:  # 4 in the config
                h_directionless = h_preact
                h_direction = encode_direction(position, direction=direction)
                h = torch.cat([h, h_direction], dim=-1)

            h = layer(h)

            h_preact = h

            if i < len(self.mlp) - 1:
                h = self.activation(h)

        h_final = h
        if h_directionless is None:
            h_directionless = h_preact

        activation = self.map_indices_to_keys(h_final)

        if nerf_level == "coarse":
            h_density = activation["density_coarse"]
        else:
            h_density = activation["density_fine"]

        if rendering_mode == "nerf":
            if nerf_level == "coarse":
                h_channels = activation["nerf_coarse"]
            else:
                h_channels = activation["nerf_fine"]

        elif rendering_mode == "stf":
            h_channels = activation["stf"]

        density = self.density_activation(h_density)
        signed_distance = self.sdf_activation(activation["sdf"])
        channels = self.channel_activation(h_channels)

        # signed_distance is consumed by decode_to_mesh via the stf rendering mode.
        return MLPNeRFModelOutput(density=density, signed_distance=signed_distance, channels=channels, ts=ts)


class ChannelsProj(nn.Module):
    def __init__(
        self,
        *,
        vectors: int,
        channels: int,
        d_latent: int,
    ):
        super().__init__()
        self.proj = nn.Linear(d_latent, vectors * channels)
        self.norm = nn.LayerNorm(channels)
        self.d_latent = d_latent
        self.vectors = vectors
        self.channels = channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_bvd = x
        w_vcd = self.proj.weight.view(self.vectors, self.channels, self.d_latent)
        b_vc = self.proj.bias.view(1, self.vectors, self.channels)
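        # Each of the `vectors` latent rows of x maps through its own
        # [channels, d_latent] slice of the projection weight; LayerNorm is
        # applied before the bias is added.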
        h = torch.einsum("bvd,vcd->bvc", x_bvd, w_vcd)
        h = self.norm(h)
        h = h + b_vc
        return h


class ShapEParamsProjModel(ModelMixin, ConfigMixin):
    """
    Project the latent representation of a 3D asset to obtain weights of a multi-layer perceptron (MLP).

    For more details, see the original Shap-E paper: https://arxiv.org/abs/2305.02463
    """

    @register_to_config
    def __init__(
        self,
        *,
        param_names: Tuple[str] = (
            "nerstf.mlp.0.weight",
            "nerstf.mlp.1.weight",
            "nerstf.mlp.2.weight",
            "nerstf.mlp.3.weight",
        ),
        param_shapes: Tuple[Tuple[int]] = (
            (256, 93),
            (256, 256),
            (256, 256),
            (256, 256),
        ),
        d_latent: int = 1024,
    ):
        super().__init__()

        # check inputs
        if len(param_names) != len(param_shapes):
            raise ValueError("Must provide same number of `param_names` as `param_shapes`")
        self.projections = nn.ModuleDict({})
        for k, (vectors, channels) in zip(param_names, param_shapes):
            self.projections[_sanitize_name(k)] = ChannelsProj(
                vectors=vectors,
                channels=channels,
                d_latent=d_latent,
            )

    def forward(self, x: torch.Tensor):
        out = {}
        start = 0
        for k, shape in zip(self.config.param_names, self.config.param_shapes):
            vectors, _ = shape
            end = start + vectors
            x_bvd = x[:, start:end]
            out[k] = self.projections[_sanitize_name(k)](x_bvd).reshape(len(x), *shape)
            start = end
        return out


class ShapERenderer(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        *,
        param_names: Tuple[str] = (
            "nerstf.mlp.0.weight",
            "nerstf.mlp.1.weight",
            "nerstf.mlp.2.weight",
            "nerstf.mlp.3.weight",
        ),
        param_shapes: Tuple[Tuple[int]] = (
            (256, 93),
            (256, 256),
            (256, 256),
            (256, 256),
        ),
        d_latent: int = 1024,
        d_hidden: int = 256,
        n_output: int = 12,
        n_hidden_layers: int = 6,
        act_fn: str = "swish",
        insert_direction_at: int = 4,
        background: Tuple[float] = (
            255.0,
            255.0,
            255.0,
        ),
    ):
        super().__init__()

        self.params_proj = ShapEParamsProjModel(
            param_names=param_names,
            param_shapes=param_shapes,
            d_latent=d_latent,
        )
        self.mlp = MLPNeRSTFModel(d_hidden, n_output, n_hidden_layers, act_fn, insert_direction_at)
        self.void = VoidNeRFModel(background=background, channel_scale=255.0)
        self.volume = BoundingBoxVolume(bbox_max=[1.0, 1.0, 1.0], bbox_min=[-1.0, -1.0, -1.0])
        self.mesh_decoder = MeshDecoder()

    def render_rays(self, rays, sampler, n_samples, prev_model_out=None, render_with_direction=False):
        """
        Perform volumetric rendering over a partition of possible t's in the union of rendering volumes (written
        below with some abuse of notations)

            C(r) := sum(
                transmittance(t[i]) * integrate(
                    lambda t: density(t) * channels(t) * transmittance(t), [t[i], t[i + 1]],
                ) for i in range(len(parts))
            ) + transmittance(t[-1]) * void_model(t[-1]).channels

        where

        1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the probability of light passing
           through the volume specified by [t[0], s] (a transmittance of 1 means light can pass freely).
        2) density and channels are obtained by evaluating the appropriate part.model at time t.
        3) [t[i], t[i + 1]] is defined as the range of t where the ray intersects
           (parts[i].volume \\ union(part.volume for part in parts[:i])) at the surface of the shell (if bounded).
           If the ray does not intersect, the integral over this segment is evaluated as 0 and
           transmittance(t[i + 1]) := transmittance(t[i]).
        4) The last term is the integration to infinity (e.g. [t[-1], math.inf]) that is evaluated by the void_model
           (i.e. we consider this space to be empty).

        Args:
            rays: [batch_size x ... x 2 x 3] origin and direction.
            sampler: sampler drawing ts over disjoint volume integrals.
            n_samples: number of ts to sample.
            prev_model_out: model outputs from the previous rendering step, including the previously sampled ts,
                which are merged with the new samples for the fine pass.

        Return:
            A tuple of
            - `channels`
            - an importance sampler for additional fine-grained rendering
            - raw model output
        """
        origin, direction = rays[..., 0, :], rays[..., 1, :]

        # Integrate over [t[i], t[i + 1]]

        # 1. Intersect the rays with the current volume and sample ts to integrate along.
        vrange = self.volume.intersect(origin, direction, t0_lower=None)
        ts = sampler.sample(vrange.t0, vrange.t1, n_samples)
        ts = ts.to(rays.dtype)

        if prev_model_out is not None:
            # Append the previous ts now before fprop because previous
            # rendering used a different model and we can't reuse the output.
            ts = torch.sort(torch.cat([ts, prev_model_out.ts], dim=-2), dim=-2).values

        batch_size, *_shape, _t0_dim = vrange.t0.shape
        _, *ts_shape, _ts_dim = ts.shape

        # 2. Get the points along the ray and query the model
        directions = torch.broadcast_to(direction.unsqueeze(-2), [batch_size, *ts_shape, 3])
        positions = origin.unsqueeze(-2) + ts * directions

        directions = directions.to(self.mlp.dtype)
        positions = positions.to(self.mlp.dtype)

        optional_directions = directions if render_with_direction else None

        model_out = self.mlp(
            position=positions,
            direction=optional_directions,
            ts=ts,
            nerf_level="coarse" if prev_model_out is None else "fine",
        )

        # 3. Integrate the model results
        channels, weights, transmittance = integrate_samples(
            vrange, model_out.ts, model_out.density, model_out.channels
        )

        # 4. Clean up results that do not intersect with the volume.
        transmittance = torch.where(vrange.intersected, transmittance, torch.ones_like(transmittance))
        channels = torch.where(vrange.intersected, channels, torch.zeros_like(channels))
        # 5. Integration to infinity (e.g. [t[-1], math.inf]) that is evaluated by the void_model
        # (i.e. we consider this space to be empty).
        channels = channels + transmittance * self.void(origin)

        weighted_sampler = ImportanceRaySampler(vrange, ts=model_out.ts, weights=weights)

        return channels, weighted_sampler, model_out

    @torch.no_grad()
    def decode_to_image(
        self,
        latents,
        device,
        size: int = 64,
        ray_batch_size: int = 4096,
        n_coarse_samples=64,
        n_fine_samples=128,
    ):
        # project the parameters from the generated latents
        projected_params = self.params_proj(latents)

        # update the mlp layers of the renderer
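        # state_dict() tensors share storage with the module parameters, so
        # copy_ writes the projected weights into the live MLP in place.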
        for name, param in self.mlp.state_dict().items():
            if f"nerstf.{name}" in projected_params.keys():
                param.copy_(projected_params[f"nerstf.{name}"].squeeze(0))

        # create cameras object
        camera = create_pan_cameras(size)
        rays = camera.camera_rays
        rays = rays.to(device)
        n_batches = rays.shape[1] // ray_batch_size

        coarse_sampler = StratifiedRaySampler()

        images = []

        for idx in range(n_batches):
            rays_batch = rays[:, idx * ray_batch_size : (idx + 1) * ray_batch_size]

            # render rays with coarse, stratified samples.
            _, fine_sampler, coarse_model_out = self.render_rays(rays_batch, coarse_sampler, n_coarse_samples)
            # Then, render with additional importance-weighted ray samples.
            channels, _, _ = self.render_rays(
                rays_batch, fine_sampler, n_fine_samples, prev_model_out=coarse_model_out
            )

            images.append(channels)

        images = torch.cat(images, dim=1)
        images = images.view(*camera.shape, camera.height, camera.width, -1).squeeze(0)

        return images

    @torch.no_grad()
    def decode_to_mesh(
        self,
        latents,
        device,
        grid_size: int = 128,
        query_batch_size: int = 4096,
        texture_channels: Tuple = ("R", "G", "B"),
    ):
        # 1. project the parameters from the generated latents
        projected_params = self.params_proj(latents)

        # 2. update the mlp layers of the renderer
        for name, param in self.mlp.state_dict().items():
            if f"nerstf.{name}" in projected_params.keys():
                param.copy_(projected_params[f"nerstf.{name}"].squeeze(0))

        # 3. decoding with STF rendering
        # 3.1 query the SDF values at vertices along a regular 128**3 grid
        query_points = volume_query_points(self.volume, grid_size)
        query_positions = query_points[None].repeat(1, 1, 1).to(device=device, dtype=self.mlp.dtype)

        fields = []

        for idx in range(0, query_positions.shape[1], query_batch_size):
            query_batch = query_positions[:, idx : idx + query_batch_size]
            model_out = self.mlp(
                position=query_batch, direction=None, ts=None, nerf_level="fine", rendering_mode="stf"
            )
            fields.append(model_out.signed_distance)

        # predicted SDF values
        fields = torch.cat(fields, dim=1)
        fields = fields.float()

        assert (
            len(fields.shape) == 3 and fields.shape[-1] == 1
        ), f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"

        fields = fields.reshape(1, *([grid_size] * 3))

        # create a padded (grid_size + 2)^3 grid, forcing a negative border
        # around the SDFs to close off all the models.
        full_grid = torch.zeros(
            1,
            grid_size + 2,
            grid_size + 2,
            grid_size + 2,
            device=fields.device,
            dtype=fields.dtype,
        )
        full_grid.fill_(-1.0)
        full_grid[:, 1:-1, 1:-1, 1:-1] = fields
        fields = full_grid

        # apply a differentiable implementation of Marching Cubes to construct meshes
        raw_meshes = []
        mesh_mask = []

        for field in fields:
            raw_mesh = self.mesh_decoder(field, self.volume.bbox_min, self.volume.bbox_max - self.volume.bbox_min)
            mesh_mask.append(True)
            raw_meshes.append(raw_mesh)

        mesh_mask = torch.tensor(mesh_mask, device=fields.device)
        max_vertices = max(len(m.verts) for m in raw_meshes)

        # 3.2 query the texture color head at each vertex of the resulting mesh.
        texture_query_positions = torch.stack(
            [m.verts[torch.arange(0, max_vertices) % len(m.verts)] for m in raw_meshes],
            dim=0,
        )
        texture_query_positions = texture_query_positions.to(device=device, dtype=self.mlp.dtype)

        textures = []

        for idx in range(0, texture_query_positions.shape[1], query_batch_size):
            query_batch = texture_query_positions[:, idx : idx + query_batch_size]
            texture_model_out = self.mlp(
                position=query_batch, direction=None, ts=None, nerf_level="fine", rendering_mode="stf"
            )
            textures.append(texture_model_out.channels)

        # predicted texture colors
        textures = torch.cat(textures, dim=1)

        textures = _convert_srgb_to_linear(textures)
        textures = textures.float()

        # 3.3 augment the mesh with texture data
        assert len(textures.shape) == 3 and textures.shape[-1] == len(
            texture_channels
        ), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"

        for m, texture in zip(raw_meshes, textures):
            texture = texture[: len(m.verts)]
            m.vertex_channels = dict(zip(texture_channels, texture.unbind(-1)))

        return raw_meshes[0]