|  | from dataclasses import dataclass | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | import torch.nn as nn | 
					
						
						|  | from torch import distributed as tdist | 
					
						
						|  | from torch.nn import functional as F | 
					
						
						|  | import math | 
					
						
						|  | import mcubes | 
					
						
						|  | import numpy as np | 
					
						
						|  | from einops import repeat, rearrange | 
					
						
						|  | from skimage import measure | 
					
						
						|  |  | 
					
						
						|  | from craftsman.utils.base import BaseModule | 
					
						
						|  | from craftsman.utils.typing import * | 
					
						
						|  | from craftsman.utils.misc import get_world_size | 
					
						
						|  | from craftsman.utils.ops import generate_dense_grid_points | 
					
						
						|  |  | 
					
						
						|  | VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"] | 
					
						
						|  |  | 
					
						
						|  | class FourierEmbedder(nn.Module): | 
					
						
						|  | def __init__(self, | 
					
						
						|  | num_freqs: int = 6, | 
					
						
						|  | logspace: bool = True, | 
					
						
						|  | input_dim: int = 3, | 
					
						
						|  | include_input: bool = True, | 
					
						
						|  | include_pi: bool = True) -> None: | 
					
						
						|  | super().__init__() | 
					
						
						|  |  | 
					
						
						|  | if logspace: | 
					
						
						|  | frequencies = 2.0 ** torch.arange( | 
					
						
						|  | num_freqs, | 
					
						
						|  | dtype=torch.float32 | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | frequencies = torch.linspace( | 
					
						
						|  | 1.0, | 
					
						
						|  | 2.0 ** (num_freqs - 1), | 
					
						
						|  | num_freqs, | 
					
						
						|  | dtype=torch.float32 | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | if include_pi: | 
					
						
						|  | frequencies *= torch.pi | 
					
						
						|  |  | 
					
						
						|  | self.register_buffer("frequencies", frequencies, persistent=False) | 
					
						
						|  | self.include_input = include_input | 
					
						
						|  | self.num_freqs = num_freqs | 
					
						
						|  |  | 
					
						
						|  | self.out_dim = self.get_dims(input_dim) | 
					
						
						|  |  | 
					
						
						|  | def get_dims(self, input_dim): | 
					
						
						|  | temp = 1 if self.include_input or self.num_freqs == 0 else 0 | 
					
						
						|  | out_dim = input_dim * (self.num_freqs * 2 + temp) | 
					
						
						|  |  | 
					
						
						|  | return out_dim | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x: torch.Tensor) -> torch.Tensor: | 
					
						
						|  | if self.num_freqs > 0: | 
					
						
						|  | embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1) | 
					
						
						|  | if self.include_input: | 
					
						
						|  | return torch.cat((x, embed.sin(), embed.cos()), dim=-1) | 
					
						
						|  | else: | 
					
						
						|  | return torch.cat((embed.sin(), embed.cos()), dim=-1) | 
					
						
						|  | else: | 
					
						
						|  | return x | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class LearnedFourierEmbedder(nn.Module): | 
					
						
						|  | def __init__(self, input_dim, dim): | 
					
						
						|  | super().__init__() | 
					
						
						|  | assert (dim % 2) == 0 | 
					
						
						|  | half_dim = dim // 2 | 
					
						
						|  | per_channel_dim = half_dim // input_dim | 
					
						
						|  | self.weights = nn.Parameter(torch.randn(per_channel_dim)) | 
					
						
						|  |  | 
					
						
						|  | self.out_dim = self.get_dims(input_dim) | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x): | 
					
						
						|  |  | 
					
						
						|  | freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1) | 
					
						
						|  | fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1) | 
					
						
						|  | return fouriered | 
					
						
						|  |  | 
					
						
						|  | def get_dims(self, input_dim): | 
					
						
						|  | return input_dim * (self.weights.shape[0] * 2 + 1) | 
					
						
						|  |  | 
					
						
						|  | class Sine(nn.Module): | 
					
						
						|  | def __init__(self, w0 = 1.): | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.w0 = w0 | 
					
						
						|  | def forward(self, x): | 
					
						
						|  | return torch.sin(self.w0 * x) | 
					
						
						|  |  | 
					
						
						|  | class Siren(nn.Module): | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | in_dim, | 
					
						
						|  | out_dim, | 
					
						
						|  | w0 = 1., | 
					
						
						|  | c = 6., | 
					
						
						|  | is_first = False, | 
					
						
						|  | use_bias = True, | 
					
						
						|  | activation = None, | 
					
						
						|  | dropout = 0. | 
					
						
						|  | ): | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.in_dim = in_dim | 
					
						
						|  | self.out_dim = out_dim | 
					
						
						|  | self.is_first = is_first | 
					
						
						|  |  | 
					
						
						|  | weight = torch.zeros(out_dim, in_dim) | 
					
						
						|  | bias = torch.zeros(out_dim) if use_bias else None | 
					
						
						|  | self.init_(weight, bias, c = c, w0 = w0) | 
					
						
						|  |  | 
					
						
						|  | self.weight = nn.Parameter(weight) | 
					
						
						|  | self.bias = nn.Parameter(bias) if use_bias else None | 
					
						
						|  | self.activation = Sine(w0) if activation is None else activation | 
					
						
						|  | self.dropout = nn.Dropout(dropout) | 
					
						
						|  |  | 
					
						
						|  | def init_(self, weight, bias, c, w0): | 
					
						
						|  | dim = self.in_dim | 
					
						
						|  |  | 
					
						
						|  | w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0) | 
					
						
						|  | weight.uniform_(-w_std, w_std) | 
					
						
						|  |  | 
					
						
						|  | if bias is not None: | 
					
						
						|  | bias.uniform_(-w_std, w_std) | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x): | 
					
						
						|  | out =  F.linear(x, self.weight, self.bias) | 
					
						
						|  | out = self.activation(out) | 
					
						
						|  | out = self.dropout(out) | 
					
						
						|  | return out | 
					
						
						|  |  | 
					
						
						|  | def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, include_pi=True): | 
					
						
						|  | if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1): | 
					
						
						|  | return nn.Identity(), input_dim | 
					
						
						|  |  | 
					
						
						|  | elif embed_type == "fourier": | 
					
						
						|  | embedder_obj = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) | 
					
						
						|  |  | 
					
						
						|  | elif embed_type == "learned_fourier": | 
					
						
						|  | embedder_obj = LearnedFourierEmbedder(in_channels=input_dim, dim=num_freqs) | 
					
						
						|  |  | 
					
						
						|  | elif embed_type == "siren": | 
					
						
						|  | embedder_obj = Siren(in_dim=input_dim, out_dim=num_freqs * input_dim * 2 + input_dim) | 
					
						
						|  |  | 
					
						
						|  | elif embed_type == "hashgrid": | 
					
						
						|  | raise NotImplementedError | 
					
						
						|  |  | 
					
						
						|  | elif embed_type == "sphere_harmonic": | 
					
						
						|  | raise NotImplementedError | 
					
						
						|  |  | 
					
						
						|  | else: | 
					
						
						|  | raise ValueError(f"{embed_type} is not valid. Currently only supprts {VALID_EMBED_TYPES}") | 
					
						
						|  | return embedder_obj | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class AutoEncoder(BaseModule): | 
					
						
						|  | @dataclass | 
					
						
						|  | class Config(BaseModule.Config): | 
					
						
						|  | pretrained_model_name_or_path: str = "" | 
					
						
						|  | num_latents: int = 256 | 
					
						
						|  | embed_dim: int = 64 | 
					
						
						|  | width: int = 768 | 
					
						
						|  |  | 
					
						
						|  | cfg: Config | 
					
						
						|  |  | 
					
						
						|  | def configure(self) -> None: | 
					
						
						|  | super().configure() | 
					
						
						|  |  | 
					
						
						|  | def encode(self, x: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor]: | 
					
						
						|  | raise NotImplementedError | 
					
						
						|  |  | 
					
						
						|  | def decode(self, z: torch.FloatTensor) -> torch.FloatTensor: | 
					
						
						|  | raise NotImplementedError | 
					
						
						|  |  | 
					
						
						|  | def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True): | 
					
						
						|  | posterior = None | 
					
						
						|  | if self.cfg.embed_dim > 0: | 
					
						
						|  | moments = self.pre_kl(latents) | 
					
						
						|  | posterior = DiagonalGaussianDistribution(moments, feat_dim=-1) | 
					
						
						|  | if sample_posterior: | 
					
						
						|  | kl_embed = posterior.sample() | 
					
						
						|  | else: | 
					
						
						|  | kl_embed = posterior.mode() | 
					
						
						|  | else: | 
					
						
						|  | kl_embed = latents | 
					
						
						|  | return kl_embed, posterior | 
					
						
						|  |  | 
					
						
						|  | def forward(self, | 
					
						
						|  | surface: torch.FloatTensor, | 
					
						
						|  | queries: torch.FloatTensor, | 
					
						
						|  | sample_posterior: bool = True): | 
					
						
						|  | shape_latents, kl_embed, posterior = self.encode(surface, sample_posterior=sample_posterior) | 
					
						
						|  |  | 
					
						
						|  | latents = self.decode(kl_embed) | 
					
						
						|  |  | 
					
						
						|  | logits = self.query(queries, latents) | 
					
						
						|  |  | 
					
						
						|  | return shape_latents, latents, posterior, logits | 
					
						
						|  |  | 
					
						
						|  | def query(self, queries: torch.FloatTensor, latents: torch.FloatTensor) -> torch.FloatTensor: | 
					
						
						|  | raise NotImplementedError | 
					
						
						|  |  | 
					
						
						|  | @torch.no_grad() | 
					
						
						|  | def extract_geometry(self, | 
					
						
						|  | latents: torch.FloatTensor, | 
					
						
						|  | bounds: Union[Tuple[float], List[float], float] = (-1.05, -1.05, -1.05, 1.05, 1.05, 1.05), | 
					
						
						|  | octree_depth: int = 8, | 
					
						
						|  | num_chunks: int = 10000, | 
					
						
						|  | ): | 
					
						
						|  |  | 
					
						
						|  | if isinstance(bounds, float): | 
					
						
						|  | bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] | 
					
						
						|  |  | 
					
						
						|  | bbox_min = np.array(bounds[0:3]) | 
					
						
						|  | bbox_max = np.array(bounds[3:6]) | 
					
						
						|  | bbox_size = bbox_max - bbox_min | 
					
						
						|  |  | 
					
						
						|  | xyz_samples, grid_size, length = generate_dense_grid_points( | 
					
						
						|  | bbox_min=bbox_min, | 
					
						
						|  | bbox_max=bbox_max, | 
					
						
						|  | octree_depth=octree_depth, | 
					
						
						|  | indexing="ij" | 
					
						
						|  | ) | 
					
						
						|  | xyz_samples = torch.FloatTensor(xyz_samples) | 
					
						
						|  | batch_size = latents.shape[0] | 
					
						
						|  |  | 
					
						
						|  | batch_logits = [] | 
					
						
						|  | for start in range(0, xyz_samples.shape[0], num_chunks): | 
					
						
						|  | queries = xyz_samples[start: start + num_chunks, :].to(latents) | 
					
						
						|  | batch_queries = repeat(queries, "p c -> b p c", b=batch_size) | 
					
						
						|  |  | 
					
						
						|  | logits = self.query(batch_queries, latents) | 
					
						
						|  | batch_logits.append(logits.cpu()) | 
					
						
						|  |  | 
					
						
						|  | grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).float().numpy() | 
					
						
						|  |  | 
					
						
						|  | mesh_v_f = [] | 
					
						
						|  | has_surface = np.zeros((batch_size,), dtype=np.bool_) | 
					
						
						|  | for i in range(batch_size): | 
					
						
						|  | try: | 
					
						
						|  | vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner") | 
					
						
						|  |  | 
					
						
						|  | vertices = vertices / grid_size * bbox_size + bbox_min | 
					
						
						|  | faces = faces[:, [2, 1, 0]] | 
					
						
						|  | mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces))) | 
					
						
						|  | has_surface[i] = True | 
					
						
						|  | except: | 
					
						
						|  | mesh_v_f.append((None, None)) | 
					
						
						|  | has_surface[i] = False | 
					
						
						|  |  | 
					
						
						|  | return mesh_v_f, has_surface | 
					
						
						|  |  | 
					
						
						|  | class DiagonalGaussianDistribution(object): | 
					
						
						|  | def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1): | 
					
						
						|  | self.feat_dim = feat_dim | 
					
						
						|  | self.parameters = parameters | 
					
						
						|  |  | 
					
						
						|  | if isinstance(parameters, list): | 
					
						
						|  | self.mean = parameters[0] | 
					
						
						|  | self.logvar = parameters[1] | 
					
						
						|  | else: | 
					
						
						|  | self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim) | 
					
						
						|  |  | 
					
						
						|  | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) | 
					
						
						|  | self.deterministic = deterministic | 
					
						
						|  | self.std = torch.exp(0.5 * self.logvar) | 
					
						
						|  | self.var = torch.exp(self.logvar) | 
					
						
						|  | if self.deterministic: | 
					
						
						|  | self.var = self.std = torch.zeros_like(self.mean) | 
					
						
						|  |  | 
					
						
						|  | def sample(self): | 
					
						
						|  | x = self.mean + self.std * torch.randn_like(self.mean) | 
					
						
						|  | return x | 
					
						
						|  |  | 
					
						
						|  | def kl(self, other=None, dims=(1, 2)): | 
					
						
						|  | if self.deterministic: | 
					
						
						|  | return torch.Tensor([0.]) | 
					
						
						|  | else: | 
					
						
						|  | if other is None: | 
					
						
						|  | return 0.5 * torch.mean(torch.pow(self.mean, 2) | 
					
						
						|  | + self.var - 1.0 - self.logvar, | 
					
						
						|  | dim=dims) | 
					
						
						|  | else: | 
					
						
						|  | return 0.5 * torch.mean( | 
					
						
						|  | torch.pow(self.mean - other.mean, 2) / other.var | 
					
						
						|  | + self.var / other.var - 1.0 - self.logvar + other.logvar, | 
					
						
						|  | dim=dims) | 
					
						
						|  |  | 
					
						
						|  | def nll(self, sample, dims=(1, 2)): | 
					
						
						|  | if self.deterministic: | 
					
						
						|  | return torch.Tensor([0.]) | 
					
						
						|  | logtwopi = np.log(2.0 * np.pi) | 
					
						
						|  | return 0.5 * torch.sum( | 
					
						
						|  | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, | 
					
						
						|  | dim=dims) | 
					
						
						|  |  | 
					
						
						|  | def mode(self): | 
					
						
						|  | return self.mean | 
					
						
						|  |  |