Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torch.autograd import Function | |
| from torch.autograd.function import once_differentiable | |
| from torch.cuda.amp import custom_bwd, custom_fwd | |
| try: | |
| import _shencoder as _backend | |
| except ImportError: | |
| from .backend import _backend | |
class _sh_encoder(Function):
    """Autograd wrapper around the compiled spherical-harmonics (SH) encoder.

    ``forward`` delegates to ``_backend.sh_encode_forward`` and ``backward`` to
    ``_backend.sh_encode_backward``. The Jacobian ``dy_dx`` is only computed in
    the forward pass when ``calc_grad_inputs`` is true. Otherwise ``backward``
    returns no gradient for ``inputs``.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # force float32 for better precision
    def forward(ctx, inputs, degree, calc_grad_inputs=False):
        """Encode ``inputs`` with SH basis functions up to ``degree`` bands.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            inputs: 2-D tensor ``[B, input_dim]``; the caller in this module
                passes coordinates scaled to ``[-1, 1]``.
            degree: number of SH bands; the output has ``degree ** 2`` columns.
            calc_grad_inputs: when true, also compute the flattened Jacobian
                ``dy_dx`` so that ``backward`` can return a gradient for
                ``inputs``.

        Returns:
            Tensor ``[B, degree ** 2]`` with the same dtype/device as ``inputs``.
        """
        # The backend receives raw tensors plus explicit sizes, so hand it a
        # dense layout.
        inputs = inputs.contiguous()
        B, input_dim = inputs.shape  # batch size, coord dim
        output_dim = degree ** 2

        outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)
        # Per-row Jacobian, allocated (and filled by the backend) only on request.
        dy_dx = (
            torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device)
            if calc_grad_inputs
            else None
        )

        _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx)

        ctx.save_for_backward(inputs, dy_dx)
        ctx.dims = [B, input_dim, degree]
        return outputs

    # NOTE(review): `@once_differentiable` was deliberately left disabled
    # upstream. The backend backward is opaque to autograd, so double-backward
    # through this op presumably yields no gradient; confirm before relying on it.
    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        """Propagate ``grad`` (``[B, degree ** 2]``) back to ``inputs``.

        Returns one entry per ``forward`` argument: the gradient for
        ``inputs`` (``None`` when the Jacobian was not computed), then ``None``
        for the non-tensor ``degree`` and ``calc_grad_inputs``.
        """
        inputs, dy_dx = ctx.saved_tensors
        if dy_dx is None:
            # Forward ran with calc_grad_inputs=False: nothing to propagate.
            return None, None, None

        grad = grad.contiguous()
        B, input_dim, degree = ctx.dims
        grad_inputs = torch.zeros_like(inputs)
        _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs)
        return grad_inputs, None, None


sh_encode = _sh_encoder.apply
class SHEncoder(nn.Module):
    """Spherical-harmonics positional encoder for 3-D coordinates.

    Maps inputs of shape ``[..., 3]`` to ``[..., degree ** 2]`` features via
    the ``sh_encode`` autograd op.
    """

    def __init__(self, input_dim=3, degree=4):
        """Configure the encoder.

        Args:
            input_dim: coordinate dimensionality; only 3 is supported.
            degree: number of SH bands, in [1, 8]; the encoder emits
                ``degree ** 2`` features per input point.

        Raises:
            AssertionError: if ``input_dim != 3`` or ``degree`` is outside [1, 8].
        """
        super().__init__()
        self.input_dim = input_dim  # coord dims, must be 3
        self.degree = degree  # number of SH bands, 1 ~ 8
        self.output_dim = degree ** 2

        assert self.input_dim == 3, "SH encoder only support input dim == 3"
        assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]"

    def __repr__(self):
        return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}"

    def forward(self, inputs, size=1):
        """Encode ``inputs`` of shape ``[..., input_dim]``.

        Args:
            inputs: positions in ``[-size, size]``; they are divided by
                ``size`` so the kernel sees values in ``[-1, 1]``.
            size: half-extent of the input domain.

        Returns:
            Tensor of shape ``[..., degree ** 2]``.

        Raises:
            ValueError: if the last dimension of ``inputs`` is not ``input_dim``.
        """
        # Without this check, a wrong trailing dim is silently reinterpreted by
        # the reshape below and only fails (if at all) after the kernel runs.
        if inputs.dim() == 0 or inputs.shape[-1] != self.input_dim:
            raise ValueError(
                f"SHEncoder expects inputs with last dim {self.input_dim}, "
                f"got shape {tuple(inputs.shape)}"
            )

        inputs = inputs / size  # [-1, 1]
        prefix_shape = list(inputs.shape[:-1])
        inputs = inputs.reshape(-1, self.input_dim)
        # Only request the Jacobian from the kernel when autograd needs a
        # gradient w.r.t. the inputs.
        outputs = sh_encode(inputs, self.degree, inputs.requires_grad)
        return outputs.reshape(prefix_shape + [self.output_dim])