import gsplat as gs
import numpy as np
import torch
import torch.nn.functional as F
from easydict import EasyDict as edict


class GSplatRenderer:
    """Render Gaussian-splat scenes with ``gsplat.rasterization``.

    Rendering options (passed as a dict to ``__init__``; any subset may be
    given, the rest keep their defaults):

    * ``resolution`` -- output side length in pixels; the output is square.
      Must be set (non-``None``) before calling :meth:`render`.
    * ``near`` / ``far`` -- clipping planes forwarded to the rasterizer;
      ``None`` falls back to ``DEFAULT_NEAR`` / ``DEFAULT_FAR``.
    * ``ssaa`` -- supersampling factor; the image is rasterized at
      ``resolution * ssaa`` and, when > 1, downsampled back to ``resolution``.
    * ``bg_color`` -- the string ``'random'`` (black or white with equal
      probability, drawn once per ``render`` call from NumPy's global RNG)
      or an RGB triple.
    """

    # Clipping planes used when the ``near``/``far`` options are left as None.
    DEFAULT_NEAR = 0.01
    DEFAULT_FAR = 1000.0

    def __init__(self, rendering_options=None) -> None:
        """Create a renderer.

        Args:
            rendering_options: Optional mapping that overrides entries of the
                default rendering options (see the class docstring).
        """
        # Not read by ``render``; kept on the instance for external consumers.
        self.pipe = edict({
            "kernel_size": 0.1,
            "convert_SHs_python": False,
            "compute_cov3D_python": False,
            "scale_modifier": 1.0,
            "debug": False,
            "use_mip_gaussian": True,
        })
        self.rendering_options = edict({
            "resolution": None,
            "near": None,
            "far": None,
            "ssaa": 1,
            "bg_color": 'random',
        })
        if rendering_options:
            self.rendering_options.update(rendering_options)
        # Background used by the most recent ``render`` call.
        self.bg_color = None

    def _resolve_bg_color(self, device) -> torch.Tensor:
        """Return the RGB background tensor for a single render call."""
        bg = self.rendering_options["bg_color"]
        if isinstance(bg, str) and bg == 'random':
            # Exactly one RNG draw per call: black or white with equal odds.
            color = torch.zeros(3, dtype=torch.float32, device=device)
            if np.random.rand() < 0.5:
                color += 1
            return color
        return torch.as_tensor(bg, dtype=torch.float32, device=device)

    @staticmethod
    def _scale_intrinsics(intrinsics: torch.Tensor, width: int, height: int) -> torch.Tensor:
        """Scale fx/cx by ``width`` and fy/cy by ``height``; add a batch dim.

        The scaling implies ``intrinsics`` is expressed as fractions of the
        image size (normalized) -- presumably; confirm against callers.
        """
        ks = intrinsics.clone()
        ks[0, 0] *= width
        ks[1, 1] *= height
        ks[0, 2] *= width
        ks[1, 2] *= height
        return ks.unsqueeze(0)

    def render(
        self,
        gaussian,
        extrinsics: torch.Tensor,
        intrinsics: torch.Tensor,
        colors_overwrite: torch.Tensor = None,
    ) -> edict:
        """Rasterize ``gaussian`` from a single camera.

        Args:
            gaussian: Object exposing ``get_xyz``, ``get_rotation``,
                ``get_scaling``, ``get_opacity`` and ``get_features``.
            extrinsics: World-to-camera view matrix for one camera (a batch
                dimension is added here).
            intrinsics: 3x3 camera matrix, scaled to pixels internally by
                the render size.
            colors_overwrite: Optional per-Gaussian colours used instead of
                ``sigmoid(gaussian.get_features.squeeze(1))``.

        Returns:
            ``edict`` with key ``'color'``: a ``(3, H, W)`` image where
            ``H == W == resolution``.
        """
        resolution = self.rendering_options["resolution"]
        ssaa = self.rendering_options["ssaa"]
        height = width = resolution * ssaa

        means = gaussian.get_xyz
        # Background must live on the same device as the Gaussians.
        self.bg_color = self._resolve_bg_color(means.device)

        near = self.rendering_options.get("near")
        far = self.rendering_options.get("far")
        near_plane = self.DEFAULT_NEAR if near is None else near
        far_plane = self.DEFAULT_FAR if far is None else far

        if colors_overwrite is not None:
            colors = colors_overwrite
        else:
            colors = torch.sigmoid(gaussian.get_features.squeeze(1))

        render_colors, _render_alphas, _meta = gs.rasterization(
            means=means,
            quats=F.normalize(gaussian.get_rotation, dim=-1),
            # NOTE(review): scales are divided by the (normalized) focal
            # length fx; the reason is not evident here -- confirm intended.
            scales=gaussian.get_scaling / intrinsics[0, 0],
            opacities=gaussian.get_opacity.squeeze(-1),
            colors=colors.unsqueeze(0),
            viewmats=extrinsics.unsqueeze(0),
            Ks=self._scale_intrinsics(intrinsics, width, height),
            width=width,
            height=height,
            near_plane=near_plane,
            far_plane=far_plane,
            radius_clip=3.0,
            eps2d=0.3,
            render_mode="RGB",
            backgrounds=self.bg_color.unsqueeze(0),
            camera_model="pinhole",
        )
        # First (only) camera, RGB channels, HWC -> CHW.
        rendered_image = render_colors[0, ..., 0:3].permute(2, 0, 1)

        if ssaa > 1:
            # Downsample the supersampled render back to the target size.
            # Index the batch dim explicitly rather than squeeze() so no
            # other singleton dimension is dropped.
            rendered_image = F.interpolate(
                rendered_image[None],
                size=(resolution, resolution),
                mode='bilinear',
                align_corners=False,
                antialias=True,
            )[0]

        return edict({'color': rendered_image})