Spaces:
Paused
Paused
Create mesh_optim.py
Browse files- freesplatter/utils/mesh_optim.py +203 -0
freesplatter/utils/mesh_optim.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import utils3d
|
| 5 |
+
import nvdiffrast.torch as dr
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import trimesh
|
| 8 |
+
import trimesh.visual
|
| 9 |
+
import xatlas
|
| 10 |
+
import cv2
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import fast_simplification
|
| 13 |
+
|
| 14 |
+
from freesplatter.utils.mesh import Mesh
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def parametrize_mesh(
    vertices: np.ndarray,
    faces: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Parametrize a mesh to a texture space, using xatlas.

    Args:
        vertices (np.ndarray): Vertices of the mesh. Shape (V, 3).
        faces (np.ndarray): Faces of the mesh. Shape (F, 3).

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: ``(vertices, faces, uvs)``.
        ``vertices`` is the input re-indexed by the vertex mapping xatlas
        returns, ``faces`` is the index array xatlas returns, and ``uvs`` is
        the UV array xatlas returns.
        NOTE(review): xatlas presumably duplicates vertices along chart seams,
        so the returned vertex count may differ from V -- confirm against the
        xatlas documentation.
    """
    vmapping, indices, uvs = xatlas.parametrize(vertices, faces)

    # Gather the original positions through the mapping so that the vertex
    # array lines up with the new face indices and the UVs.
    return vertices[vmapping], indices, uvs
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _cosine_annealing(step: int, total_steps: int, start_lr: float, end_lr: float) -> float:
    """Return the learning rate at ``step`` on a cosine curve from ``start_lr`` to ``end_lr``."""
    return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))


def _tv_loss(texture: torch.Tensor) -> torch.Tensor:
    """Return the L1 difference between neighbouring texels along axes 1 and 2 of ``texture``."""
    along_rows = torch.nn.functional.l1_loss(texture[:, :-1, :, :], texture[:, 1:, :, :])
    along_cols = torch.nn.functional.l1_loss(texture[:, :, :-1, :], texture[:, :, 1:, :])
    return along_rows + along_cols


def _bake_texture_fast(
    vertices: torch.Tensor,
    faces: torch.Tensor,
    uvs: torch.Tensor,
    observations: List[torch.Tensor],
    masks: List[torch.Tensor],
    views: List[torch.Tensor],
    projections: List[torch.Tensor],
    texture_size: int,
    verbose: bool,
) -> np.ndarray:
    """Average every observed pixel into its nearest texel, then inpaint untouched texels."""
    device = vertices.device
    texel_count = texture_size * texture_size
    texture = torch.zeros((texel_count, 3), dtype=torch.float32, device=device)
    texture_weights = torch.zeros(texel_count, dtype=torch.float32, device=device)
    rastctx = utils3d.torch.RastContext(backend='cuda')

    progress = tqdm(
        zip(observations, masks, views, projections),
        total=len(observations), disable=not verbose, desc='Texture baking (fast)',
    )
    for observation, view_mask, view, projection in progress:
        with torch.no_grad():
            rast = utils3d.torch.rasterize_triangle_faces(
                rastctx, vertices[None], faces, observation.shape[1], observation.shape[0],
                uv=uvs[None], view=view, projection=projection,
            )
            # NOTE(review): the UV map is flipped along rows while the rasterized
            # mask is not -- presumably the rasterizer returns them with opposite
            # row order; confirm against utils3d.
            uv_map = rast['uv'][0].detach().flip(0)
            # Combine rasterized coverage with the mask belonging to this view
            # (not always the first view's mask).
            mask = rast['mask'][0].detach().bool() & view_mask

        # Nearest-neighbour texel lookup. Clamp so that a UV of exactly 1.0
        # cannot produce an index outside the texture.
        texel = (uv_map * texture_size).floor().long().clamp(0, texture_size - 1)
        obs = observation[mask]
        texel = texel[mask]
        # Flatten (u, v) to a row-major index with the v axis inverted.
        idx = texel[:, 0] + (texture_size - texel[:, 1] - 1) * texture_size
        texture = texture.scatter_add(0, idx.view(-1, 1).expand(-1, 3), obs)
        texture_weights = texture_weights.scatter_add(
            0, idx, torch.ones((obs.shape[0]), dtype=torch.float32, device=device)
        )

    # Turn accumulated sums into means wherever at least one pixel landed.
    covered = texture_weights > 0
    texture[covered] /= texture_weights[covered][:, None]
    texture = np.clip(texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255).astype(np.uint8)

    # Fill texels that received no observation.
    holes = (texture_weights == 0).cpu().numpy().astype(np.uint8).reshape(texture_size, texture_size)
    return cv2.inpaint(texture, holes, 3, cv2.INPAINT_TELEA)


def _bake_texture_opt(
    vertices: torch.Tensor,
    faces: torch.Tensor,
    uvs: torch.Tensor,
    observations: List[torch.Tensor],
    masks: List[torch.Tensor],
    views: List[torch.Tensor],
    projections: List[torch.Tensor],
    texture_size: int,
    lambda_tv: float,
    verbose: bool,
) -> np.ndarray:
    """Fit a texture by gradient descent on the masked L1 error of randomly chosen views."""
    rastctx = utils3d.torch.RastContext(backend='cuda')
    # NOTE(review): observations and masks are flipped along rows here while the
    # UV maps are used as returned -- presumably to match the rasterizer's row
    # order; confirm against utils3d.
    observations = [obs.flip(0) for obs in observations]
    masks = [m.flip(0) for m in masks]

    # The UV lookups depend only on geometry and cameras, so compute them once.
    uv_maps = []
    uv_derivs = []
    progress = tqdm(
        zip(observations, views, projections),
        total=len(views), disable=not verbose, desc='Texture baking (opt): UV',
    )
    for observation, view, projection in progress:
        with torch.no_grad():
            rast = utils3d.torch.rasterize_triangle_faces(
                rastctx, vertices[None], faces, observation.shape[1], observation.shape[0],
                uv=uvs[None], view=view, projection=projection,
            )
            uv_maps.append(rast['uv'].detach())
            uv_derivs.append(rast['uv_dr'].detach())

    texture = torch.nn.Parameter(
        torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).cuda()
    )
    start_lr, end_lr = 1e-2, 1e-5
    optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=start_lr)

    total_steps = 2500
    with tqdm(total=total_steps, disable=not verbose, desc='Texture baking (opt): optimizing') as pbar:
        for step in range(total_steps):
            optimizer.zero_grad()
            # One randomly selected view per step.
            selected = np.random.randint(0, len(views))
            uv, uv_dr = uv_maps[selected], uv_derivs[selected]
            observation, mask = observations[selected], masks[selected]
            render = dr.texture(texture, uv, uv_dr)[0]
            loss = torch.nn.functional.l1_loss(render[mask], observation[mask])
            if lambda_tv > 0:
                loss = loss + lambda_tv * _tv_loss(texture)
            loss.backward()
            optimizer.step()
            # The new learning rate takes effect from the next step onwards.
            optimizer.param_groups[0]['lr'] = _cosine_annealing(step, total_steps, start_lr, end_lr)
            pbar.set_postfix({'loss': loss.item()})
            pbar.update()

    texture = np.clip(texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)
    # Rasterize the UV layout itself (UVs rescaled from [0, 1] to [-1, 1]) to
    # find texels not covered by any triangle, then inpaint those.
    holes = 1 - utils3d.torch.rasterize_triangle_faces(
        rastctx, (uvs * 2 - 1)[None], faces, texture_size, texture_size
    )['mask'][0].detach().cpu().numpy().astype(np.uint8)
    return cv2.inpaint(texture, holes, 3, cv2.INPAINT_TELEA)


def bake_texture(
    vertices: np.ndarray,
    faces: np.ndarray,
    uvs: np.ndarray,
    observations: List[np.ndarray],
    masks: List[np.ndarray],
    extrinsics: List[np.ndarray],
    intrinsics: List[np.ndarray],
    texture_size: int = 2048,
    near: float = 0.1,
    far: float = 10.0,
    mode: Literal['fast', 'opt'] = 'opt',
    lambda_tv: float = 1e-2,
    verbose: bool = False,
) -> np.ndarray:
    """
    Bake texture to a mesh from multiple observations.

    All inputs are uploaded to the CUDA device, so a CUDA-capable GPU is
    required. The per-view lists (``observations``, ``masks``, ``extrinsics``,
    ``intrinsics``) are consumed in parallel, one entry per view.

    Args:
        vertices (np.ndarray): Vertices of the mesh. Shape (V, 3).
        faces (np.ndarray): Faces of the mesh. Shape (F, 3).
        uvs (np.ndarray): UV coordinates of the mesh. Shape (V, 2).
        observations (List[np.ndarray]): List of observations. Each observation is a 2D image. Shape (H, W, 3).
            The result is scaled by 255 before conversion to uint8, so values
            are presumably expected in [0, 1] -- confirm with callers.
        masks (List[np.ndarray]): List of masks. Each mask is a 2D image. Shape (H, W).
            Pixels with a value above 1e-2 count as foreground.
        extrinsics (List[np.ndarray]): List of extrinsics. Shape (4, 4).
        intrinsics (List[np.ndarray]): List of intrinsics. Shape (3, 3).
        texture_size (int): Size of the texture.
        near (float): Near plane of the camera.
        far (float): Far plane of the camera.
        mode (Literal['fast', 'opt']): Mode of texture baking. 'fast' averages
            observed pixels into their nearest texel; 'opt' optimizes the
            texture with Adam against randomly selected views.
        lambda_tv (float): Weight of total variation loss in optimization.
            Only used in 'opt' mode; values <= 0 disable the term.
        verbose (bool): Whether to print progress.

    Returns:
        np.ndarray: uint8 texture image of shape (texture_size, texture_size, 3),
        with unobserved texels filled by OpenCV inpainting.

    Raises:
        ValueError: If ``mode`` is neither 'fast' nor 'opt'.
    """
    # Reject an unknown mode before any GPU work is started.
    if mode not in ('fast', 'opt'):
        raise ValueError(f'Unknown mode: {mode}')

    vertices = torch.tensor(vertices).float().cuda()
    faces = torch.tensor(faces.astype(np.int32)).cuda()
    uvs = torch.tensor(uvs).float().cuda()
    observations = [torch.tensor(obs).float().cuda() for obs in observations]
    masks = [torch.tensor(m > 1e-2).bool().cuda() for m in masks]
    views = [
        utils3d.torch.extrinsics_to_view(torch.tensor(extr).float().cuda())
        for extr in extrinsics
    ]
    projections = [
        utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).float().cuda(), near, far)
        for intr in intrinsics
    ]

    if mode == 'fast':
        return _bake_texture_fast(
            vertices, faces, uvs, observations, masks, views, projections,
            texture_size, verbose,
        )
    return _bake_texture_opt(
        vertices, faces, uvs, observations, masks, views, projections,
        texture_size, lambda_tv, verbose,
    )
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def optimize_mesh(
    mesh: Mesh,
    images: torch.Tensor,
    masks: torch.Tensor,
    extrinsics: torch.Tensor,
    intrinsics: torch.Tensor,
    simplify: float = 0.95,
    texture_size: int = 1024,
    verbose: bool = False,
) -> trimesh.Trimesh:
    """
    Simplify, UV-unwrap and texture an extracted mesh.

    The mesh is decimated, parametrized with ``parametrize_mesh``, textured
    with ``bake_texture`` in 'opt' mode from the given views, and returned as
    a textured ``trimesh.Trimesh`` (nothing is written to disk here).

    Args:
        mesh (Mesh): Extracted mesh. Its ``v`` and ``f`` tensors are read as
            vertices and faces.
        images (torch.Tensor): Observed views, indexed along the first
            dimension; each entry is passed to ``bake_texture`` as one observation.
        masks (torch.Tensor): Per-view masks, indexed along the first dimension.
        extrinsics (torch.Tensor): Per-view camera extrinsics, indexed along
            the first dimension.
        intrinsics (torch.Tensor): Per-view camera intrinsics, indexed along
            the first dimension.
        simplify (float): Ratio of faces to remove in simplification. A larger
            ratio is used automatically when needed to get down to 50000 faces.
        texture_size (int): Size of the texture.
        verbose (bool): Whether to print progress.

    Returns:
        trimesh.Trimesh: The simplified mesh with UVs and the baked texture
        attached as ``TextureVisuals``.

    Raises:
        ValueError: If the mesh has no faces.
    """
    vertices = mesh.v.cpu().numpy()
    faces = mesh.f.cpu().numpy()

    face_count = faces.shape[0]
    if face_count == 0:
        # Fail with a clear message instead of dividing by zero below.
        raise ValueError('Cannot optimize a mesh with no faces')

    # mesh simplification: remove at least `simplify` of the faces, and more
    # if that is what it takes to reach the face budget.
    max_faces = 50000
    mesh_reduction = max(1 - max_faces / face_count, simplify)
    vertices, faces = fast_simplification.simplify(
        vertices, faces, target_reduction=mesh_reduction)

    # parametrize mesh
    vertices, faces, uvs = parametrize_mesh(vertices, faces)

    # bake texture; bake_texture takes per-view lists of numpy arrays and
    # converts the face indices to int32 itself.
    texture = bake_texture(
        vertices.astype(float), faces, uvs,
        [image.cpu().numpy() for image in images],
        [mask.cpu().numpy() for mask in masks],
        [extrinsic.cpu().numpy() for extrinsic in extrinsics],
        [intrinsic.cpu().numpy() for intrinsic in intrinsics],
        texture_size=texture_size,
        mode='opt',
        lambda_tv=0.01,
        verbose=verbose,
    )
    texture = Image.fromarray(texture)

    # rotate mesh: maps (x, y, z) to (-x, z, y).
    # NOTE(review): presumably converts to the up-axis convention of the
    # target viewer/format -- confirm.
    axis_swap = np.array([[-1, 0, 0], [0, 0, 1], [0, 1, 0]]).astype(float)
    vertices = vertices.astype(float) @ axis_swap
    return trimesh.Trimesh(
        vertices, faces,
        visual=trimesh.visual.TextureVisuals(uv=uvs, image=texture),
    )
|