ReconViaGen / trellis /renderers /gsplat_renderer.py
Stable-X's picture
Upload 288 files
f3ff4f1 verified
import gsplat as gs
import numpy as np
import torch
import torch.nn.functional as F
from easydict import EasyDict as edict
class GSplatRenderer:
def __init__(self, rendering_options={}) -> None:
self.pipe = edict({
"kernel_size": 0.1,
"convert_SHs_python": False,
"compute_cov3D_python": False,
"scale_modifier": 1.0,
"debug": False,
"use_mip_gaussian": True
})
self.rendering_options = edict({
"resolution": None,
"near": None,
"far": None,
"ssaa": 1,
"bg_color": 'random',
})
self.rendering_options.update(rendering_options)
self.bg_color = None
def render(
self,
gaussian,
extrinsics: torch.Tensor,
intrinsics: torch.Tensor,
colors_overwrite: torch.Tensor = None
) -> edict:
resolution = self.rendering_options["resolution"]
ssaa = self.rendering_options["ssaa"]
if self.rendering_options["bg_color"] == 'random':
self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda")
if np.random.rand() < 0.5:
self.bg_color += 1
else:
self.bg_color = torch.tensor(
self.rendering_options["bg_color"],
dtype=torch.float32,
device="cuda"
)
height = resolution * ssaa
width = resolution * ssaa
# Set up background color
if self.rendering_options["bg_color"] == 'random':
self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda")
if np.random.rand() < 0.5:
self.bg_color += 1
else:
self.bg_color = torch.tensor(
self.rendering_options["bg_color"],
dtype=torch.float32,
device="cuda"
)
Ks_scaled = intrinsics.clone()
Ks_scaled[0, 0] *= width
Ks_scaled[1, 1] *= height
Ks_scaled[0, 2] *= width
Ks_scaled[1, 2] *= height
Ks_scaled = Ks_scaled.unsqueeze(0)
near_plane = 0.01
far_plane = 1000.0
# Rasterize with gsplat
render_colors, render_alphas, meta = gs.rasterization(
means=gaussian.get_xyz,
quats=F.normalize(gaussian.get_rotation, dim=-1),
scales=gaussian.get_scaling / intrinsics[0, 0],
opacities=gaussian.get_opacity.squeeze(-1),
colors=colors_overwrite.unsqueeze(0) if colors_overwrite is not None else torch.sigmoid(
gaussian.get_features.squeeze(1)).unsqueeze(0),
viewmats=extrinsics.unsqueeze(0),
Ks=Ks_scaled,
width=width,
height=height,
near_plane=near_plane,
far_plane=far_plane,
radius_clip=3.0,
eps2d=0.3,
render_mode="RGB",
backgrounds=self.bg_color.unsqueeze(0),
camera_model="pinhole"
)
rendered_image = render_colors[0, ..., 0:3].permute(2, 0, 1)
# Apply supersampling if needed
if ssaa > 1:
rendered_image = F.interpolate(
rendered_image[None],
size=(resolution, resolution),
mode='bilinear',
align_corners=False,
antialias=True
).squeeze()
return edict({'color': rendered_image})