Mariam-Elz commited on
Commit
a320c78
·
verified ·
1 Parent(s): 3ab0227

Upload mesh.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. mesh.py +845 -845
mesh.py CHANGED
@@ -1,845 +1,845 @@
1
- import os
2
- import cv2
3
- import torch
4
- import trimesh
5
- import numpy as np
6
-
7
- from kiui.op import safe_normalize, dot
8
- from kiui.typing import *
9
-
10
- class Mesh:
11
- """
12
- A torch-native trimesh class, with support for ``ply/obj/glb`` formats.
13
-
14
- Note:
15
- This class only supports one mesh with a single texture image (an albedo texture and a metallic-roughness texture).
16
- """
17
- def __init__(
18
- self,
19
- v: Optional[Tensor] = None,
20
- f: Optional[Tensor] = None,
21
- vn: Optional[Tensor] = None,
22
- fn: Optional[Tensor] = None,
23
- vt: Optional[Tensor] = None,
24
- ft: Optional[Tensor] = None,
25
- vc: Optional[Tensor] = None, # vertex color
26
- albedo: Optional[Tensor] = None,
27
- metallicRoughness: Optional[Tensor] = None,
28
- device: Optional[torch.device] = None,
29
- ):
30
- """Init a mesh directly using all attributes.
31
-
32
- Args:
33
- v (Optional[Tensor]): vertices, float [N, 3]. Defaults to None.
34
- f (Optional[Tensor]): faces, int [M, 3]. Defaults to None.
35
- vn (Optional[Tensor]): vertex normals, float [N, 3]. Defaults to None.
36
- fn (Optional[Tensor]): faces for normals, int [M, 3]. Defaults to None.
37
- vt (Optional[Tensor]): vertex uv coordinates, float [N, 2]. Defaults to None.
38
- ft (Optional[Tensor]): faces for uvs, int [M, 3]. Defaults to None.
39
- vc (Optional[Tensor]): vertex colors, float [N, 3]. Defaults to None.
40
- albedo (Optional[Tensor]): albedo texture, float [H, W, 3], RGB format. Defaults to None.
41
- metallicRoughness (Optional[Tensor]): metallic-roughness texture, float [H, W, 3], metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]. Defaults to None.
42
- device (Optional[torch.device]): torch device. Defaults to None.
43
- """
44
- self.device = device
45
- self.v = v
46
- self.vn = vn
47
- self.vt = vt
48
- self.f = f
49
- self.fn = fn
50
- self.ft = ft
51
- # will first see if there is vertex color to use
52
- self.vc = vc
53
- # only support a single albedo image
54
- self.albedo = albedo
55
- # pbr extension, metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]
56
- # ref: https://registry.khronos.org/glTF/specs/2.0/glTF-2.0.html
57
- self.metallicRoughness = metallicRoughness
58
-
59
- self.ori_center = 0
60
- self.ori_scale = 1
61
-
62
- @classmethod
63
- def load(cls, path, resize=True, clean=False, renormal=True, retex=False, bound=0.9, front_dir='+z', **kwargs):
64
- """load mesh from path.
65
-
66
- Args:
67
- path (str): path to mesh file, supports ply, obj, glb.
68
- clean (bool, optional): perform mesh cleaning at load (e.g., merge close vertices). Defaults to False.
69
- resize (bool, optional): auto resize the mesh using ``bound`` into [-bound, bound]^3. Defaults to True.
70
- renormal (bool, optional): re-calc the vertex normals. Defaults to True.
71
- retex (bool, optional): re-calc the uv coordinates, will overwrite the existing uv coordinates. Defaults to False.
72
- bound (float, optional): bound to resize. Defaults to 0.9.
73
- front_dir (str, optional): front-view direction of the mesh, should be [+-][xyz][ 123]. Defaults to '+z'.
74
- device (torch.device, optional): torch device. Defaults to None.
75
-
76
- Note:
77
- a ``device`` keyword argument can be provided to specify the torch device.
78
- If it's not provided, we will try to use ``'cuda'`` as the device if it's available.
79
-
80
- Returns:
81
- Mesh: the loaded Mesh object.
82
- """
83
- # obj supports face uv
84
- if path.endswith(".obj"):
85
- mesh = cls.load_obj(path, **kwargs)
86
- # trimesh only supports vertex uv, but can load more formats
87
- else:
88
- mesh = cls.load_trimesh(path, **kwargs)
89
-
90
- # clean
91
- if clean:
92
- from kiui.mesh_utils import clean_mesh
93
- vertices = mesh.v.detach().cpu().numpy()
94
- triangles = mesh.f.detach().cpu().numpy()
95
- vertices, triangles = clean_mesh(vertices, triangles, remesh=False)
96
- mesh.v = torch.from_numpy(vertices).contiguous().float().to(mesh.device)
97
- mesh.f = torch.from_numpy(triangles).contiguous().int().to(mesh.device)
98
-
99
- print(f"[Mesh loading] v: {mesh.v.shape}, f: {mesh.f.shape}")
100
- # auto-normalize
101
- if resize:
102
- mesh.auto_size(bound=bound)
103
- # auto-fix normal
104
- if renormal or mesh.vn is None:
105
- mesh.auto_normal()
106
- print(f"[Mesh loading] vn: {mesh.vn.shape}, fn: {mesh.fn.shape}")
107
- # auto-fix texcoords
108
- if retex or (mesh.albedo is not None and mesh.vt is None):
109
- mesh.auto_uv(cache_path=path)
110
- print(f"[Mesh loading] vt: {mesh.vt.shape}, ft: {mesh.ft.shape}")
111
-
112
- # rotate front dir to +z
113
- if front_dir != "+z":
114
- # axis switch
115
- if "-z" in front_dir:
116
- T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -1]], device=mesh.device, dtype=torch.float32)
117
- elif "+x" in front_dir:
118
- T = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
119
- elif "-x" in front_dir:
120
- T = torch.tensor([[0, 0, -1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
121
- elif "+y" in front_dir:
122
- T = torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
123
- elif "-y" in front_dir:
124
- T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
125
- else:
126
- T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
127
- # rotation (how many 90 degrees)
128
- if '1' in front_dir:
129
- T @= torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
130
- elif '2' in front_dir:
131
- T @= torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
132
- elif '3' in front_dir:
133
- T @= torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
134
- mesh.v @= T
135
- mesh.vn @= T
136
-
137
- return mesh
138
-
139
- # load from obj file
140
- @classmethod
141
- def load_obj(cls, path, albedo_path=None, device=None):
142
- """load an ``obj`` mesh.
143
-
144
- Args:
145
- path (str): path to mesh.
146
- albedo_path (str, optional): path to the albedo texture image, will overwrite the existing texture path if specified in mtl. Defaults to None.
147
- device (torch.device, optional): torch device. Defaults to None.
148
-
149
- Note:
150
- We will try to read `mtl` path from `obj`, else we assume the file name is the same as `obj` but with `mtl` extension.
151
- The `usemtl` statement is ignored, and we only use the last material path in `mtl` file.
152
-
153
- Returns:
154
- Mesh: the loaded Mesh object.
155
- """
156
- assert os.path.splitext(path)[-1] == ".obj"
157
-
158
- mesh = cls()
159
-
160
- # device
161
- if device is None:
162
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
163
-
164
- mesh.device = device
165
-
166
- # load obj
167
- with open(path, "r") as f:
168
- lines = f.readlines()
169
-
170
- def parse_f_v(fv):
171
- # pass in a vertex term of a face, return {v, vt, vn} (-1 if not provided)
172
- # supported forms:
173
- # f v1 v2 v3
174
- # f v1/vt1 v2/vt2 v3/vt3
175
- # f v1/vt1/vn1 v2/vt2/vn2 v3/vt3/vn3
176
- # f v1//vn1 v2//vn2 v3//vn3
177
- xs = [int(x) - 1 if x != "" else -1 for x in fv.split("/")]
178
- xs.extend([-1] * (3 - len(xs)))
179
- return xs[0], xs[1], xs[2]
180
-
181
- vertices, texcoords, normals = [], [], []
182
- faces, tfaces, nfaces = [], [], []
183
- mtl_path = None
184
-
185
- for line in lines:
186
- split_line = line.split()
187
- # empty line
188
- if len(split_line) == 0:
189
- continue
190
- prefix = split_line[0].lower()
191
- # mtllib
192
- if prefix == "mtllib":
193
- mtl_path = split_line[1]
194
- # usemtl
195
- elif prefix == "usemtl":
196
- pass # ignored
197
- # v/vn/vt
198
- elif prefix == "v":
199
- vertices.append([float(v) for v in split_line[1:]])
200
- elif prefix == "vn":
201
- normals.append([float(v) for v in split_line[1:]])
202
- elif prefix == "vt":
203
- val = [float(v) for v in split_line[1:]]
204
- texcoords.append([val[0], 1.0 - val[1]])
205
- elif prefix == "f":
206
- vs = split_line[1:]
207
- nv = len(vs)
208
- v0, t0, n0 = parse_f_v(vs[0])
209
- for i in range(nv - 2): # triangulate (assume vertices are ordered)
210
- v1, t1, n1 = parse_f_v(vs[i + 1])
211
- v2, t2, n2 = parse_f_v(vs[i + 2])
212
- faces.append([v0, v1, v2])
213
- tfaces.append([t0, t1, t2])
214
- nfaces.append([n0, n1, n2])
215
-
216
- mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
217
- mesh.vt = (
218
- torch.tensor(texcoords, dtype=torch.float32, device=device)
219
- if len(texcoords) > 0
220
- else None
221
- )
222
- mesh.vn = (
223
- torch.tensor(normals, dtype=torch.float32, device=device)
224
- if len(normals) > 0
225
- else None
226
- )
227
-
228
- mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
229
- mesh.ft = (
230
- torch.tensor(tfaces, dtype=torch.int32, device=device)
231
- if len(texcoords) > 0
232
- else None
233
- )
234
- mesh.fn = (
235
- torch.tensor(nfaces, dtype=torch.int32, device=device)
236
- if len(normals) > 0
237
- else None
238
- )
239
-
240
- # see if there is vertex color
241
- use_vertex_color = False
242
- if mesh.v.shape[1] == 6:
243
- use_vertex_color = True
244
- mesh.vc = mesh.v[:, 3:]
245
- mesh.v = mesh.v[:, :3]
246
- print(f"[load_obj] use vertex color: {mesh.vc.shape}")
247
-
248
- # try to load texture image
249
- if not use_vertex_color:
250
- # try to retrieve mtl file
251
- mtl_path_candidates = []
252
- if mtl_path is not None:
253
- mtl_path_candidates.append(mtl_path)
254
- mtl_path_candidates.append(os.path.join(os.path.dirname(path), mtl_path))
255
- mtl_path_candidates.append(path.replace(".obj", ".mtl"))
256
-
257
- mtl_path = None
258
- for candidate in mtl_path_candidates:
259
- if os.path.exists(candidate):
260
- mtl_path = candidate
261
- break
262
-
263
- # if albedo_path is not provided, try retrieve it from mtl
264
- metallic_path = None
265
- roughness_path = None
266
- if mtl_path is not None and albedo_path is None:
267
- with open(mtl_path, "r") as f:
268
- lines = f.readlines()
269
-
270
- for line in lines:
271
- split_line = line.split()
272
- # empty line
273
- if len(split_line) == 0:
274
- continue
275
- prefix = split_line[0]
276
-
277
- if "map_Kd" in prefix:
278
- # assume relative path!
279
- albedo_path = os.path.join(os.path.dirname(path), split_line[1])
280
- print(f"[load_obj] use texture from: {albedo_path}")
281
- elif "map_Pm" in prefix:
282
- metallic_path = os.path.join(os.path.dirname(path), split_line[1])
283
- elif "map_Pr" in prefix:
284
- roughness_path = os.path.join(os.path.dirname(path), split_line[1])
285
-
286
- # still not found albedo_path, or the path doesn't exist
287
- if albedo_path is None or not os.path.exists(albedo_path):
288
- # init an empty texture
289
- print(f"[load_obj] init empty albedo!")
290
- # albedo = np.random.rand(1024, 1024, 3).astype(np.float32)
291
- albedo = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5]) # default color
292
- else:
293
- albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED)
294
- albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB)
295
- albedo = albedo.astype(np.float32) / 255
296
- print(f"[load_obj] load texture: {albedo.shape}")
297
-
298
- mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device)
299
-
300
- # try to load metallic and roughness
301
- if metallic_path is not None and roughness_path is not None:
302
- print(f"[load_obj] load metallicRoughness from: {metallic_path}, {roughness_path}")
303
- metallic = cv2.imread(metallic_path, cv2.IMREAD_UNCHANGED)
304
- metallic = metallic.astype(np.float32) / 255
305
- roughness = cv2.imread(roughness_path, cv2.IMREAD_UNCHANGED)
306
- roughness = roughness.astype(np.float32) / 255
307
- metallicRoughness = np.stack([np.zeros_like(metallic), roughness, metallic], axis=-1)
308
-
309
- mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous()
310
-
311
- return mesh
312
-
313
- @classmethod
314
- def load_trimesh(cls, path, device=None):
315
- """load a mesh using ``trimesh.load()``.
316
-
317
- Can load various formats like ``glb`` and serves as a fallback.
318
-
319
- Note:
320
- We will try to merge all meshes if the glb contains more than one,
321
- but **this may cause the texture to lose**, since we only support one texture image!
322
-
323
- Args:
324
- path (str): path to the mesh file.
325
- device (torch.device, optional): torch device. Defaults to None.
326
-
327
- Returns:
328
- Mesh: the loaded Mesh object.
329
- """
330
- mesh = cls()
331
-
332
- # device
333
- if device is None:
334
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
335
-
336
- mesh.device = device
337
-
338
- # use trimesh to load ply/glb
339
- _data = trimesh.load(path)
340
- if isinstance(_data, trimesh.Scene):
341
- if len(_data.geometry) == 1:
342
- _mesh = list(_data.geometry.values())[0]
343
- else:
344
- print(f"[load_trimesh] concatenating {len(_data.geometry)} meshes.")
345
- _concat = []
346
- # loop the scene graph and apply transform to each mesh
347
- scene_graph = _data.graph.to_flattened() # dict {name: {transform: 4x4 mat, geometry: str}}
348
- for k, v in scene_graph.items():
349
- name = v['geometry']
350
- if name in _data.geometry and isinstance(_data.geometry[name], trimesh.Trimesh):
351
- transform = v['transform']
352
- _concat.append(_data.geometry[name].apply_transform(transform))
353
- _mesh = trimesh.util.concatenate(_concat)
354
- else:
355
- _mesh = _data
356
-
357
- if _mesh.visual.kind == 'vertex':
358
- vertex_colors = _mesh.visual.vertex_colors
359
- vertex_colors = np.array(vertex_colors[..., :3]).astype(np.float32) / 255
360
- mesh.vc = torch.tensor(vertex_colors, dtype=torch.float32, device=device)
361
- print(f"[load_trimesh] use vertex color: {mesh.vc.shape}")
362
- elif _mesh.visual.kind == 'texture':
363
- _material = _mesh.visual.material
364
- if isinstance(_material, trimesh.visual.material.PBRMaterial):
365
- texture = np.array(_material.baseColorTexture).astype(np.float32) / 255
366
- # load metallicRoughness if present
367
- if _material.metallicRoughnessTexture is not None:
368
- metallicRoughness = np.array(_material.metallicRoughnessTexture).astype(np.float32) / 255
369
- mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous()
370
- elif isinstance(_material, trimesh.visual.material.SimpleMaterial):
371
- texture = np.array(_material.to_pbr().baseColorTexture).astype(np.float32) / 255
372
- else:
373
- raise NotImplementedError(f"material type {type(_material)} not supported!")
374
- mesh.albedo = torch.tensor(texture[..., :3], dtype=torch.float32, device=device).contiguous()
375
- print(f"[load_trimesh] load texture: {texture.shape}")
376
- else:
377
- texture = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5])
378
- mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device)
379
- print(f"[load_trimesh] failed to load texture.")
380
-
381
- vertices = _mesh.vertices
382
-
383
- try:
384
- texcoords = _mesh.visual.uv
385
- texcoords[:, 1] = 1 - texcoords[:, 1]
386
- except Exception as e:
387
- texcoords = None
388
-
389
- try:
390
- normals = _mesh.vertex_normals
391
- except Exception as e:
392
- normals = None
393
-
394
- # trimesh only support vertex uv...
395
- faces = tfaces = nfaces = _mesh.faces
396
-
397
- mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
398
- mesh.vt = (
399
- torch.tensor(texcoords, dtype=torch.float32, device=device)
400
- if texcoords is not None
401
- else None
402
- )
403
- mesh.vn = (
404
- torch.tensor(normals, dtype=torch.float32, device=device)
405
- if normals is not None
406
- else None
407
- )
408
-
409
- mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
410
- mesh.ft = (
411
- torch.tensor(tfaces, dtype=torch.int32, device=device)
412
- if texcoords is not None
413
- else None
414
- )
415
- mesh.fn = (
416
- torch.tensor(nfaces, dtype=torch.int32, device=device)
417
- if normals is not None
418
- else None
419
- )
420
-
421
- return mesh
422
-
423
- # sample surface (using trimesh)
424
- def sample_surface(self, count: int):
425
- """sample points on the surface of the mesh.
426
-
427
- Args:
428
- count (int): number of points to sample.
429
-
430
- Returns:
431
- torch.Tensor: the sampled points, float [count, 3].
432
- """
433
- _mesh = trimesh.Trimesh(vertices=self.v.detach().cpu().numpy(), faces=self.f.detach().cpu().numpy())
434
- points, face_idx = trimesh.sample.sample_surface(_mesh, count)
435
- points = torch.from_numpy(points).float().to(self.device)
436
- return points
437
-
438
- # aabb
439
- def aabb(self):
440
- """get the axis-aligned bounding box of the mesh.
441
-
442
- Returns:
443
- Tuple[torch.Tensor]: the min xyz and max xyz of the mesh.
444
- """
445
- return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values
446
-
447
- # unit size
448
- @torch.no_grad()
449
- def auto_size(self, bound=0.9):
450
- """auto resize the mesh.
451
-
452
- Args:
453
- bound (float, optional): resizing into ``[-bound, bound]^3``. Defaults to 0.9.
454
- """
455
- vmin, vmax = self.aabb()
456
- self.ori_center = (vmax + vmin) / 2
457
- self.ori_scale = 2 * bound / torch.max(vmax - vmin).item()
458
- self.v = (self.v - self.ori_center) * self.ori_scale
459
-
460
- def auto_normal(self):
461
- """auto calculate the vertex normals.
462
- """
463
- i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long()
464
- v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :]
465
-
466
- face_normals = torch.cross(v1 - v0, v2 - v0)
467
-
468
- # Splat face normals to vertices
469
- vn = torch.zeros_like(self.v)
470
- vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
471
- vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
472
- vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
473
-
474
- # Normalize, replace zero (degenerated) normals with some default value
475
- vn = torch.where(
476
- dot(vn, vn) > 1e-20,
477
- vn,
478
- torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device),
479
- )
480
- vn = safe_normalize(vn)
481
-
482
- self.vn = vn
483
- self.fn = self.f
484
-
485
- def auto_uv(self, cache_path=None, vmap=True):
486
- """auto calculate the uv coordinates.
487
-
488
- Args:
489
- cache_path (str, optional): path to save/load the uv cache as a npz file, this can avoid calculating uv every time when loading the same mesh, which is time-consuming. Defaults to None.
490
- vmap (bool, optional): remap vertices based on uv coordinates, so each v correspond to a unique vt (necessary for formats like gltf).
491
- Usually this will duplicate the vertices on the edge of uv atlas. Defaults to True.
492
- """
493
- # try to load cache
494
- if cache_path is not None:
495
- cache_path = os.path.splitext(cache_path)[0] + "_uv.npz"
496
- if cache_path is not None and os.path.exists(cache_path):
497
- data = np.load(cache_path)
498
- vt_np, ft_np, vmapping = data["vt"], data["ft"], data["vmapping"]
499
- else:
500
- import xatlas
501
-
502
- v_np = self.v.detach().cpu().numpy()
503
- f_np = self.f.detach().int().cpu().numpy()
504
- atlas = xatlas.Atlas()
505
- atlas.add_mesh(v_np, f_np)
506
- chart_options = xatlas.ChartOptions()
507
- # chart_options.max_iterations = 4
508
- atlas.generate(chart_options=chart_options)
509
- vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
510
-
511
- # save to cache
512
- if cache_path is not None:
513
- np.savez(cache_path, vt=vt_np, ft=ft_np, vmapping=vmapping)
514
-
515
- vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device)
516
- ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device)
517
- self.vt = vt
518
- self.ft = ft
519
-
520
- if vmap:
521
- vmapping = torch.from_numpy(vmapping.astype(np.int64)).long().to(self.device)
522
- self.align_v_to_vt(vmapping)
523
-
524
- def align_v_to_vt(self, vmapping=None):
525
- """ remap v/f and vn/fn to vt/ft.
526
-
527
- Args:
528
- vmapping (np.ndarray, optional): the mapping relationship from f to ft. Defaults to None.
529
- """
530
- if vmapping is None:
531
- ft = self.ft.view(-1).long()
532
- f = self.f.view(-1).long()
533
- vmapping = torch.zeros(self.vt.shape[0], dtype=torch.long, device=self.device)
534
- vmapping[ft] = f # scatter, randomly choose one if index is not unique
535
-
536
- self.v = self.v[vmapping]
537
- self.f = self.ft
538
-
539
- if self.vn is not None:
540
- self.vn = self.vn[vmapping]
541
- self.fn = self.ft
542
-
543
- def to(self, device):
544
- """move all tensor attributes to device.
545
-
546
- Args:
547
- device (torch.device): target device.
548
-
549
- Returns:
550
- Mesh: self.
551
- """
552
- self.device = device
553
- for name in ["v", "f", "vn", "fn", "vt", "ft", "albedo", "vc", "metallicRoughness"]:
554
- tensor = getattr(self, name)
555
- if tensor is not None:
556
- setattr(self, name, tensor.to(device))
557
- return self
558
-
559
- def write(self, path):
560
- """write the mesh to a path.
561
-
562
- Args:
563
- path (str): path to write, supports ply, obj and glb.
564
- """
565
- if path.endswith(".ply"):
566
- self.write_ply(path)
567
- elif path.endswith(".obj"):
568
- self.write_obj(path)
569
- elif path.endswith(".glb") or path.endswith(".gltf"):
570
- self.write_glb(path)
571
- else:
572
- raise NotImplementedError(f"format {path} not supported!")
573
-
574
- def write_ply(self, path):
575
- """write the mesh in ply format. Only for geometry!
576
-
577
- Args:
578
- path (str): path to write.
579
- """
580
-
581
- if self.albedo is not None:
582
- print(f'[WARN] ply format does not support exporting texture, will ignore!')
583
-
584
- v_np = self.v.detach().cpu().numpy()
585
- f_np = self.f.detach().cpu().numpy()
586
-
587
- _mesh = trimesh.Trimesh(vertices=v_np, faces=f_np)
588
- _mesh.export(path)
589
-
590
-
591
- def write_glb(self, path):
592
- """write the mesh in glb/gltf format.
593
- This will create a scene with a single mesh.
594
-
595
- Args:
596
- path (str): path to write.
597
- """
598
-
599
- # assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0]
600
- if self.vt is not None and self.v.shape[0] != self.vt.shape[0]:
601
- self.align_v_to_vt()
602
-
603
- import pygltflib
604
-
605
- f_np = self.f.detach().cpu().numpy().astype(np.uint32)
606
- f_np_blob = f_np.flatten().tobytes()
607
-
608
- v_np = self.v.detach().cpu().numpy().astype(np.float32)
609
- v_np_blob = v_np.tobytes()
610
-
611
- blob = f_np_blob + v_np_blob
612
- byteOffset = len(blob)
613
-
614
- # base mesh
615
- gltf = pygltflib.GLTF2(
616
- scene=0,
617
- scenes=[pygltflib.Scene(nodes=[0])],
618
- nodes=[pygltflib.Node(mesh=0)],
619
- meshes=[pygltflib.Mesh(primitives=[pygltflib.Primitive(
620
- # indices to accessors (0 is triangles)
621
- attributes=pygltflib.Attributes(
622
- POSITION=1,
623
- ),
624
- indices=0,
625
- )])],
626
- buffers=[
627
- pygltflib.Buffer(byteLength=len(f_np_blob) + len(v_np_blob))
628
- ],
629
- # buffer view (based on dtype)
630
- bufferViews=[
631
- # triangles; as flatten (element) array
632
- pygltflib.BufferView(
633
- buffer=0,
634
- byteLength=len(f_np_blob),
635
- target=pygltflib.ELEMENT_ARRAY_BUFFER, # GL_ELEMENT_ARRAY_BUFFER (34963)
636
- ),
637
- # positions; as vec3 array
638
- pygltflib.BufferView(
639
- buffer=0,
640
- byteOffset=len(f_np_blob),
641
- byteLength=len(v_np_blob),
642
- byteStride=12, # vec3
643
- target=pygltflib.ARRAY_BUFFER, # GL_ARRAY_BUFFER (34962)
644
- ),
645
- ],
646
- accessors=[
647
- # 0 = triangles
648
- pygltflib.Accessor(
649
- bufferView=0,
650
- componentType=pygltflib.UNSIGNED_INT, # GL_UNSIGNED_INT (5125)
651
- count=f_np.size,
652
- type=pygltflib.SCALAR,
653
- max=[int(f_np.max())],
654
- min=[int(f_np.min())],
655
- ),
656
- # 1 = positions
657
- pygltflib.Accessor(
658
- bufferView=1,
659
- componentType=pygltflib.FLOAT, # GL_FLOAT (5126)
660
- count=len(v_np),
661
- type=pygltflib.VEC3,
662
- max=v_np.max(axis=0).tolist(),
663
- min=v_np.min(axis=0).tolist(),
664
- ),
665
- ],
666
- )
667
-
668
- # append texture info
669
- if self.vt is not None:
670
-
671
- vt_np = self.vt.detach().cpu().numpy().astype(np.float32)
672
- vt_np_blob = vt_np.tobytes()
673
-
674
- albedo = self.albedo.detach().cpu().numpy()
675
- albedo = (albedo * 255).astype(np.uint8)
676
- albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)
677
- albedo_blob = cv2.imencode('.png', albedo)[1].tobytes()
678
-
679
- # update primitive
680
- gltf.meshes[0].primitives[0].attributes.TEXCOORD_0 = 2
681
- gltf.meshes[0].primitives[0].material = 0
682
-
683
- # update materials
684
- gltf.materials.append(pygltflib.Material(
685
- pbrMetallicRoughness=pygltflib.PbrMetallicRoughness(
686
- baseColorTexture=pygltflib.TextureInfo(index=0, texCoord=0),
687
- metallicFactor=0.0,
688
- roughnessFactor=1.0,
689
- ),
690
- alphaMode=pygltflib.OPAQUE,
691
- alphaCutoff=None,
692
- doubleSided=True,
693
- ))
694
-
695
- gltf.textures.append(pygltflib.Texture(sampler=0, source=0))
696
- gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT))
697
- gltf.images.append(pygltflib.Image(bufferView=3, mimeType="image/png"))
698
-
699
- # update buffers
700
- gltf.bufferViews.append(
701
- # index = 2, texcoords; as vec2 array
702
- pygltflib.BufferView(
703
- buffer=0,
704
- byteOffset=byteOffset,
705
- byteLength=len(vt_np_blob),
706
- byteStride=8, # vec2
707
- target=pygltflib.ARRAY_BUFFER,
708
- )
709
- )
710
-
711
- gltf.accessors.append(
712
- # 2 = texcoords
713
- pygltflib.Accessor(
714
- bufferView=2,
715
- componentType=pygltflib.FLOAT,
716
- count=len(vt_np),
717
- type=pygltflib.VEC2,
718
- max=vt_np.max(axis=0).tolist(),
719
- min=vt_np.min(axis=0).tolist(),
720
- )
721
- )
722
-
723
- blob += vt_np_blob
724
- byteOffset += len(vt_np_blob)
725
-
726
- gltf.bufferViews.append(
727
- # index = 3, albedo texture; as none target
728
- pygltflib.BufferView(
729
- buffer=0,
730
- byteOffset=byteOffset,
731
- byteLength=len(albedo_blob),
732
- )
733
- )
734
-
735
- blob += albedo_blob
736
- byteOffset += len(albedo_blob)
737
-
738
- gltf.buffers[0].byteLength = byteOffset
739
-
740
- # append metllic roughness
741
- if self.metallicRoughness is not None:
742
- metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
743
- metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
744
- metallicRoughness = cv2.cvtColor(metallicRoughness, cv2.COLOR_RGB2BGR)
745
- metallicRoughness_blob = cv2.imencode('.png', metallicRoughness)[1].tobytes()
746
-
747
- # update texture definition
748
- gltf.materials[0].pbrMetallicRoughness.metallicFactor = 1.0
749
- gltf.materials[0].pbrMetallicRoughness.roughnessFactor = 1.0
750
- gltf.materials[0].pbrMetallicRoughness.metallicRoughnessTexture = pygltflib.TextureInfo(index=1, texCoord=0)
751
-
752
- gltf.textures.append(pygltflib.Texture(sampler=1, source=1))
753
- gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT))
754
- gltf.images.append(pygltflib.Image(bufferView=4, mimeType="image/png"))
755
-
756
- # update buffers
757
- gltf.bufferViews.append(
758
- # index = 4, metallicRoughness texture; as none target
759
- pygltflib.BufferView(
760
- buffer=0,
761
- byteOffset=byteOffset,
762
- byteLength=len(metallicRoughness_blob),
763
- )
764
- )
765
-
766
- blob += metallicRoughness_blob
767
- byteOffset += len(metallicRoughness_blob)
768
-
769
- gltf.buffers[0].byteLength = byteOffset
770
-
771
-
772
- # set actual data
773
- gltf.set_binary_blob(blob)
774
-
775
- # glb = b"".join(gltf.save_to_bytes())
776
- gltf.save(path)
777
-
778
-
779
- def write_obj(self, path):
780
- """write the mesh in obj format. Will also write the texture and mtl files.
781
-
782
- Args:
783
- path (str): path to write.
784
- """
785
-
786
- mtl_path = path.replace(".obj", ".mtl")
787
- albedo_path = path.replace(".obj", "_albedo.png")
788
- metallic_path = path.replace(".obj", "_metallic.png")
789
- roughness_path = path.replace(".obj", "_roughness.png")
790
-
791
- v_np = self.v.detach().cpu().numpy()
792
- vt_np = self.vt.detach().cpu().numpy() if self.vt is not None else None
793
- vn_np = self.vn.detach().cpu().numpy() if self.vn is not None else None
794
- f_np = self.f.detach().cpu().numpy()
795
- ft_np = self.ft.detach().cpu().numpy() if self.ft is not None else None
796
- fn_np = self.fn.detach().cpu().numpy() if self.fn is not None else None
797
-
798
- with open(path, "w") as fp:
799
- fp.write(f"mtllib {os.path.basename(mtl_path)} \n")
800
-
801
- for v in v_np:
802
- fp.write(f"v {v[0]} {v[1]} {v[2]} \n")
803
-
804
- if vt_np is not None:
805
- for v in vt_np:
806
- fp.write(f"vt {v[0]} {1 - v[1]} \n")
807
-
808
- if vn_np is not None:
809
- for v in vn_np:
810
- fp.write(f"vn {v[0]} {v[1]} {v[2]} \n")
811
-
812
- fp.write(f"usemtl defaultMat \n")
813
- for i in range(len(f_np)):
814
- fp.write(
815
- f'f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1 if ft_np is not None else ""}/{fn_np[i, 0] + 1 if fn_np is not None else ""} \
816
- {f_np[i, 1] + 1}/{ft_np[i, 1] + 1 if ft_np is not None else ""}/{fn_np[i, 1] + 1 if fn_np is not None else ""} \
817
- {f_np[i, 2] + 1}/{ft_np[i, 2] + 1 if ft_np is not None else ""}/{fn_np[i, 2] + 1 if fn_np is not None else ""} \n'
818
- )
819
-
820
- with open(mtl_path, "w") as fp:
821
- fp.write(f"newmtl defaultMat \n")
822
- fp.write(f"Ka 1 1 1 \n")
823
- fp.write(f"Kd 1 1 1 \n")
824
- fp.write(f"Ks 0 0 0 \n")
825
- fp.write(f"Tr 1 \n")
826
- fp.write(f"illum 1 \n")
827
- fp.write(f"Ns 0 \n")
828
- if self.albedo is not None:
829
- fp.write(f"map_Kd {os.path.basename(albedo_path)} \n")
830
- if self.metallicRoughness is not None:
831
- # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering
832
- fp.write(f"map_Pm {os.path.basename(metallic_path)} \n")
833
- fp.write(f"map_Pr {os.path.basename(roughness_path)} \n")
834
-
835
- if self.albedo is not None:
836
- albedo = self.albedo.detach().cpu().numpy()
837
- albedo = (albedo * 255).astype(np.uint8)
838
- cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR))
839
-
840
- if self.metallicRoughness is not None:
841
- metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
842
- metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
843
- cv2.imwrite(metallic_path, metallicRoughness[..., 2])
844
- cv2.imwrite(roughness_path, metallicRoughness[..., 1])
845
-
 
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import trimesh
5
+ import numpy as np
6
+
7
+ from kiui.op import safe_normalize, dot
8
+ from kiui.typing import *
9
+
10
+ class Mesh:
11
+ """
12
+ A torch-native trimesh class, with support for ``ply/obj/glb`` formats.
13
+
14
+ Note:
15
+ This class only supports one mesh with a single texture image (an albedo texture and a metallic-roughness texture).
16
+ """
17
+ def __init__(
18
+ self,
19
+ v: Optional[Tensor] = None,
20
+ f: Optional[Tensor] = None,
21
+ vn: Optional[Tensor] = None,
22
+ fn: Optional[Tensor] = None,
23
+ vt: Optional[Tensor] = None,
24
+ ft: Optional[Tensor] = None,
25
+ vc: Optional[Tensor] = None, # vertex color
26
+ albedo: Optional[Tensor] = None,
27
+ metallicRoughness: Optional[Tensor] = None,
28
+ device: Optional[torch.device] = None,
29
+ ):
30
+ """Init a mesh directly using all attributes.
31
+
32
+ Args:
33
+ v (Optional[Tensor]): vertices, float [N, 3]. Defaults to None.
34
+ f (Optional[Tensor]): faces, int [M, 3]. Defaults to None.
35
+ vn (Optional[Tensor]): vertex normals, float [N, 3]. Defaults to None.
36
+ fn (Optional[Tensor]): faces for normals, int [M, 3]. Defaults to None.
37
+ vt (Optional[Tensor]): vertex uv coordinates, float [N, 2]. Defaults to None.
38
+ ft (Optional[Tensor]): faces for uvs, int [M, 3]. Defaults to None.
39
+ vc (Optional[Tensor]): vertex colors, float [N, 3]. Defaults to None.
40
+ albedo (Optional[Tensor]): albedo texture, float [H, W, 3], RGB format. Defaults to None.
41
+ metallicRoughness (Optional[Tensor]): metallic-roughness texture, float [H, W, 3], metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]. Defaults to None.
42
+ device (Optional[torch.device]): torch device. Defaults to None.
43
+ """
44
+ self.device = device
45
+ self.v = v
46
+ self.vn = vn
47
+ self.vt = vt
48
+ self.f = f
49
+ self.fn = fn
50
+ self.ft = ft
51
+ # will first see if there is vertex color to use
52
+ self.vc = vc
53
+ # only support a single albedo image
54
+ self.albedo = albedo
55
+ # pbr extension, metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]
56
+ # ref: https://registry.khronos.org/glTF/specs/2.0/glTF-2.0.html
57
+ self.metallicRoughness = metallicRoughness
58
+
59
+ self.ori_center = 0
60
+ self.ori_scale = 1
61
+
62
+ @classmethod
63
+ def load(cls, path, resize=True, clean=False, renormal=True, retex=False, bound=0.9, front_dir='+z', **kwargs):
64
+ """load mesh from path.
65
+
66
+ Args:
67
+ path (str): path to mesh file, supports ply, obj, glb.
68
+ clean (bool, optional): perform mesh cleaning at load (e.g., merge close vertices). Defaults to False.
69
+ resize (bool, optional): auto resize the mesh using ``bound`` into [-bound, bound]^3. Defaults to True.
70
+ renormal (bool, optional): re-calc the vertex normals. Defaults to True.
71
+ retex (bool, optional): re-calc the uv coordinates, will overwrite the existing uv coordinates. Defaults to False.
72
+ bound (float, optional): bound to resize. Defaults to 0.9.
73
+ front_dir (str, optional): front-view direction of the mesh, should be [+-][xyz][ 123]. Defaults to '+z'.
74
+ device (torch.device, optional): torch device. Defaults to None.
75
+
76
+ Note:
77
+ a ``device`` keyword argument can be provided to specify the torch device.
78
+ If it's not provided, we will try to use ``'cuda'`` as the device if it's available.
79
+
80
+ Returns:
81
+ Mesh: the loaded Mesh object.
82
+ """
83
+ # obj supports face uv
84
+ if path.endswith(".obj"):
85
+ mesh = cls.load_obj(path, **kwargs)
86
+ # trimesh only supports vertex uv, but can load more formats
87
+ else:
88
+ mesh = cls.load_trimesh(path, **kwargs)
89
+
90
+ # clean
91
+ if clean:
92
+ from kiui.mesh_utils import clean_mesh
93
+ vertices = mesh.v.detach().cpu().numpy()
94
+ triangles = mesh.f.detach().cpu().numpy()
95
+ vertices, triangles = clean_mesh(vertices, triangles, remesh=False)
96
+ mesh.v = torch.from_numpy(vertices).contiguous().float().to(mesh.device)
97
+ mesh.f = torch.from_numpy(triangles).contiguous().int().to(mesh.device)
98
+
99
+ print(f"[Mesh loading] v: {mesh.v.shape}, f: {mesh.f.shape}")
100
+ # auto-normalize
101
+ if resize:
102
+ mesh.auto_size(bound=bound)
103
+ # auto-fix normal
104
+ if renormal or mesh.vn is None:
105
+ mesh.auto_normal()
106
+ print(f"[Mesh loading] vn: {mesh.vn.shape}, fn: {mesh.fn.shape}")
107
+ # auto-fix texcoords
108
+ if retex or (mesh.albedo is not None and mesh.vt is None):
109
+ mesh.auto_uv(cache_path=path)
110
+ print(f"[Mesh loading] vt: {mesh.vt.shape}, ft: {mesh.ft.shape}")
111
+
112
+ # rotate front dir to +z
113
+ if front_dir != "+z":
114
+ # axis switch
115
+ if "-z" in front_dir:
116
+ T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -1]], device=mesh.device, dtype=torch.float32)
117
+ elif "+x" in front_dir:
118
+ T = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
119
+ elif "-x" in front_dir:
120
+ T = torch.tensor([[0, 0, -1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
121
+ elif "+y" in front_dir:
122
+ T = torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
123
+ elif "-y" in front_dir:
124
+ T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
125
+ else:
126
+ T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
127
+ # rotation (how many 90 degrees)
128
+ if '1' in front_dir:
129
+ T @= torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
130
+ elif '2' in front_dir:
131
+ T @= torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
132
+ elif '3' in front_dir:
133
+ T @= torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
134
+ mesh.v @= T
135
+ mesh.vn @= T
136
+
137
+ return mesh
138
+
139
+ # load from obj file
140
+ @classmethod
141
+ def load_obj(cls, path, albedo_path=None, device=None):
142
+ """load an ``obj`` mesh.
143
+
144
+ Args:
145
+ path (str): path to mesh.
146
+ albedo_path (str, optional): path to the albedo texture image, will overwrite the existing texture path if specified in mtl. Defaults to None.
147
+ device (torch.device, optional): torch device. Defaults to None.
148
+
149
+ Note:
150
+ We will try to read `mtl` path from `obj`, else we assume the file name is the same as `obj` but with `mtl` extension.
151
+ The `usemtl` statement is ignored, and we only use the last material path in `mtl` file.
152
+
153
+ Returns:
154
+ Mesh: the loaded Mesh object.
155
+ """
156
+ assert os.path.splitext(path)[-1] == ".obj"
157
+
158
+ mesh = cls()
159
+
160
+ # device
161
+ if device is None:
162
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
163
+
164
+ mesh.device = device
165
+
166
+ # load obj
167
+ with open(path, "r") as f:
168
+ lines = f.readlines()
169
+
170
+ def parse_f_v(fv):
171
+ # pass in a vertex term of a face, return {v, vt, vn} (-1 if not provided)
172
+ # supported forms:
173
+ # f v1 v2 v3
174
+ # f v1/vt1 v2/vt2 v3/vt3
175
+ # f v1/vt1/vn1 v2/vt2/vn2 v3/vt3/vn3
176
+ # f v1//vn1 v2//vn2 v3//vn3
177
+ xs = [int(x) - 1 if x != "" else -1 for x in fv.split("/")]
178
+ xs.extend([-1] * (3 - len(xs)))
179
+ return xs[0], xs[1], xs[2]
180
+
181
+ vertices, texcoords, normals = [], [], []
182
+ faces, tfaces, nfaces = [], [], []
183
+ mtl_path = None
184
+
185
+ for line in lines:
186
+ split_line = line.split()
187
+ # empty line
188
+ if len(split_line) == 0:
189
+ continue
190
+ prefix = split_line[0].lower()
191
+ # mtllib
192
+ if prefix == "mtllib":
193
+ mtl_path = split_line[1]
194
+ # usemtl
195
+ elif prefix == "usemtl":
196
+ pass # ignored
197
+ # v/vn/vt
198
+ elif prefix == "v":
199
+ vertices.append([float(v) for v in split_line[1:]])
200
+ elif prefix == "vn":
201
+ normals.append([float(v) for v in split_line[1:]])
202
+ elif prefix == "vt":
203
+ val = [float(v) for v in split_line[1:]]
204
+ texcoords.append([val[0], 1.0 - val[1]])
205
+ elif prefix == "f":
206
+ vs = split_line[1:]
207
+ nv = len(vs)
208
+ v0, t0, n0 = parse_f_v(vs[0])
209
+ for i in range(nv - 2): # triangulate (assume vertices are ordered)
210
+ v1, t1, n1 = parse_f_v(vs[i + 1])
211
+ v2, t2, n2 = parse_f_v(vs[i + 2])
212
+ faces.append([v0, v1, v2])
213
+ tfaces.append([t0, t1, t2])
214
+ nfaces.append([n0, n1, n2])
215
+
216
+ mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
217
+ mesh.vt = (
218
+ torch.tensor(texcoords, dtype=torch.float32, device=device)
219
+ if len(texcoords) > 0
220
+ else None
221
+ )
222
+ mesh.vn = (
223
+ torch.tensor(normals, dtype=torch.float32, device=device)
224
+ if len(normals) > 0
225
+ else None
226
+ )
227
+
228
+ mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
229
+ mesh.ft = (
230
+ torch.tensor(tfaces, dtype=torch.int32, device=device)
231
+ if len(texcoords) > 0
232
+ else None
233
+ )
234
+ mesh.fn = (
235
+ torch.tensor(nfaces, dtype=torch.int32, device=device)
236
+ if len(normals) > 0
237
+ else None
238
+ )
239
+
240
+ # see if there is vertex color
241
+ use_vertex_color = False
242
+ if mesh.v.shape[1] == 6:
243
+ use_vertex_color = True
244
+ mesh.vc = mesh.v[:, 3:]
245
+ mesh.v = mesh.v[:, :3]
246
+ print(f"[load_obj] use vertex color: {mesh.vc.shape}")
247
+
248
+ # try to load texture image
249
+ if not use_vertex_color:
250
+ # try to retrieve mtl file
251
+ mtl_path_candidates = []
252
+ if mtl_path is not None:
253
+ mtl_path_candidates.append(mtl_path)
254
+ mtl_path_candidates.append(os.path.join(os.path.dirname(path), mtl_path))
255
+ mtl_path_candidates.append(path.replace(".obj", ".mtl"))
256
+
257
+ mtl_path = None
258
+ for candidate in mtl_path_candidates:
259
+ if os.path.exists(candidate):
260
+ mtl_path = candidate
261
+ break
262
+
263
+ # if albedo_path is not provided, try retrieve it from mtl
264
+ metallic_path = None
265
+ roughness_path = None
266
+ if mtl_path is not None and albedo_path is None:
267
+ with open(mtl_path, "r") as f:
268
+ lines = f.readlines()
269
+
270
+ for line in lines:
271
+ split_line = line.split()
272
+ # empty line
273
+ if len(split_line) == 0:
274
+ continue
275
+ prefix = split_line[0]
276
+
277
+ if "map_Kd" in prefix:
278
+ # assume relative path!
279
+ albedo_path = os.path.join(os.path.dirname(path), split_line[1])
280
+ print(f"[load_obj] use texture from: {albedo_path}")
281
+ elif "map_Pm" in prefix:
282
+ metallic_path = os.path.join(os.path.dirname(path), split_line[1])
283
+ elif "map_Pr" in prefix:
284
+ roughness_path = os.path.join(os.path.dirname(path), split_line[1])
285
+
286
+ # still not found albedo_path, or the path doesn't exist
287
+ if albedo_path is None or not os.path.exists(albedo_path):
288
+ # init an empty texture
289
+ print(f"[load_obj] init empty albedo!")
290
+ # albedo = np.random.rand(1024, 1024, 3).astype(np.float32)
291
+ albedo = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5]) # default color
292
+ else:
293
+ albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED)
294
+ albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB)
295
+ albedo = albedo.astype(np.float32) / 255
296
+ print(f"[load_obj] load texture: {albedo.shape}")
297
+
298
+ mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device)
299
+
300
+ # try to load metallic and roughness
301
+ if metallic_path is not None and roughness_path is not None:
302
+ print(f"[load_obj] load metallicRoughness from: {metallic_path}, {roughness_path}")
303
+ metallic = cv2.imread(metallic_path, cv2.IMREAD_UNCHANGED)
304
+ metallic = metallic.astype(np.float32) / 255
305
+ roughness = cv2.imread(roughness_path, cv2.IMREAD_UNCHANGED)
306
+ roughness = roughness.astype(np.float32) / 255
307
+ metallicRoughness = np.stack([np.zeros_like(metallic), roughness, metallic], axis=-1)
308
+
309
+ mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous()
310
+
311
+ return mesh
312
+
313
+ @classmethod
314
+ def load_trimesh(cls, path, device=None):
315
+ """load a mesh using ``trimesh.load()``.
316
+
317
+ Can load various formats like ``glb`` and serves as a fallback.
318
+
319
+ Note:
320
+ We will try to merge all meshes if the glb contains more than one,
321
+ but **this may cause the texture to lose**, since we only support one texture image!
322
+
323
+ Args:
324
+ path (str): path to the mesh file.
325
+ device (torch.device, optional): torch device. Defaults to None.
326
+
327
+ Returns:
328
+ Mesh: the loaded Mesh object.
329
+ """
330
+ mesh = cls()
331
+
332
+ # device
333
+ if device is None:
334
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
335
+
336
+ mesh.device = device
337
+
338
+ # use trimesh to load ply/glb
339
+ _data = trimesh.load(path)
340
+ if isinstance(_data, trimesh.Scene):
341
+ if len(_data.geometry) == 1:
342
+ _mesh = list(_data.geometry.values())[0]
343
+ else:
344
+ print(f"[load_trimesh] concatenating {len(_data.geometry)} meshes.")
345
+ _concat = []
346
+ # loop the scene graph and apply transform to each mesh
347
+ scene_graph = _data.graph.to_flattened() # dict {name: {transform: 4x4 mat, geometry: str}}
348
+ for k, v in scene_graph.items():
349
+ name = v['geometry']
350
+ if name in _data.geometry and isinstance(_data.geometry[name], trimesh.Trimesh):
351
+ transform = v['transform']
352
+ _concat.append(_data.geometry[name].apply_transform(transform))
353
+ _mesh = trimesh.util.concatenate(_concat)
354
+ else:
355
+ _mesh = _data
356
+
357
+ if _mesh.visual.kind == 'vertex':
358
+ vertex_colors = _mesh.visual.vertex_colors
359
+ vertex_colors = np.array(vertex_colors[..., :3]).astype(np.float32) / 255
360
+ mesh.vc = torch.tensor(vertex_colors, dtype=torch.float32, device=device)
361
+ print(f"[load_trimesh] use vertex color: {mesh.vc.shape}")
362
+ elif _mesh.visual.kind == 'texture':
363
+ _material = _mesh.visual.material
364
+ if isinstance(_material, trimesh.visual.material.PBRMaterial):
365
+ texture = np.array(_material.baseColorTexture).astype(np.float32) / 255
366
+ # load metallicRoughness if present
367
+ if _material.metallicRoughnessTexture is not None:
368
+ metallicRoughness = np.array(_material.metallicRoughnessTexture).astype(np.float32) / 255
369
+ mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous()
370
+ elif isinstance(_material, trimesh.visual.material.SimpleMaterial):
371
+ texture = np.array(_material.to_pbr().baseColorTexture).astype(np.float32) / 255
372
+ else:
373
+ raise NotImplementedError(f"material type {type(_material)} not supported!")
374
+ mesh.albedo = torch.tensor(texture[..., :3], dtype=torch.float32, device=device).contiguous()
375
+ print(f"[load_trimesh] load texture: {texture.shape}")
376
+ else:
377
+ texture = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5])
378
+ mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device)
379
+ print(f"[load_trimesh] failed to load texture.")
380
+
381
+ vertices = _mesh.vertices
382
+
383
+ try:
384
+ texcoords = _mesh.visual.uv
385
+ texcoords[:, 1] = 1 - texcoords[:, 1]
386
+ except Exception as e:
387
+ texcoords = None
388
+
389
+ try:
390
+ normals = _mesh.vertex_normals
391
+ except Exception as e:
392
+ normals = None
393
+
394
+ # trimesh only support vertex uv...
395
+ faces = tfaces = nfaces = _mesh.faces
396
+
397
+ mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
398
+ mesh.vt = (
399
+ torch.tensor(texcoords, dtype=torch.float32, device=device)
400
+ if texcoords is not None
401
+ else None
402
+ )
403
+ mesh.vn = (
404
+ torch.tensor(normals, dtype=torch.float32, device=device)
405
+ if normals is not None
406
+ else None
407
+ )
408
+
409
+ mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
410
+ mesh.ft = (
411
+ torch.tensor(tfaces, dtype=torch.int32, device=device)
412
+ if texcoords is not None
413
+ else None
414
+ )
415
+ mesh.fn = (
416
+ torch.tensor(nfaces, dtype=torch.int32, device=device)
417
+ if normals is not None
418
+ else None
419
+ )
420
+
421
+ return mesh
422
+
423
+ # sample surface (using trimesh)
424
+ def sample_surface(self, count: int):
425
+ """sample points on the surface of the mesh.
426
+
427
+ Args:
428
+ count (int): number of points to sample.
429
+
430
+ Returns:
431
+ torch.Tensor: the sampled points, float [count, 3].
432
+ """
433
+ _mesh = trimesh.Trimesh(vertices=self.v.detach().cpu().numpy(), faces=self.f.detach().cpu().numpy())
434
+ points, face_idx = trimesh.sample.sample_surface(_mesh, count)
435
+ points = torch.from_numpy(points).float().to(self.device)
436
+ return points
437
+
438
+ # aabb
439
+ def aabb(self):
440
+ """get the axis-aligned bounding box of the mesh.
441
+
442
+ Returns:
443
+ Tuple[torch.Tensor]: the min xyz and max xyz of the mesh.
444
+ """
445
+ return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values
446
+
447
+ # unit size
448
+ @torch.no_grad()
449
+ def auto_size(self, bound=0.9):
450
+ """auto resize the mesh.
451
+
452
+ Args:
453
+ bound (float, optional): resizing into ``[-bound, bound]^3``. Defaults to 0.9.
454
+ """
455
+ vmin, vmax = self.aabb()
456
+ self.ori_center = (vmax + vmin) / 2
457
+ self.ori_scale = 2 * bound / torch.max(vmax - vmin).item()
458
+ self.v = (self.v - self.ori_center) * self.ori_scale
459
+
460
+ def auto_normal(self):
461
+ """auto calculate the vertex normals.
462
+ """
463
+ i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long()
464
+ v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :]
465
+
466
+ face_normals = torch.cross(v1 - v0, v2 - v0)
467
+
468
+ # Splat face normals to vertices
469
+ vn = torch.zeros_like(self.v)
470
+ vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
471
+ vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
472
+ vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
473
+
474
+ # Normalize, replace zero (degenerated) normals with some default value
475
+ vn = torch.where(
476
+ dot(vn, vn) > 1e-20,
477
+ vn,
478
+ torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device),
479
+ )
480
+ vn = safe_normalize(vn)
481
+
482
+ self.vn = vn
483
+ self.fn = self.f
484
+
485
+ def auto_uv(self, cache_path=None, vmap=True):
486
+ """auto calculate the uv coordinates.
487
+
488
+ Args:
489
+ cache_path (str, optional): path to save/load the uv cache as a npz file, this can avoid calculating uv every time when loading the same mesh, which is time-consuming. Defaults to None.
490
+ vmap (bool, optional): remap vertices based on uv coordinates, so each v correspond to a unique vt (necessary for formats like gltf).
491
+ Usually this will duplicate the vertices on the edge of uv atlas. Defaults to True.
492
+ """
493
+ # try to load cache
494
+ if cache_path is not None:
495
+ cache_path = os.path.splitext(cache_path)[0] + "_uv.npz"
496
+ if cache_path is not None and os.path.exists(cache_path):
497
+ data = np.load(cache_path)
498
+ vt_np, ft_np, vmapping = data["vt"], data["ft"], data["vmapping"]
499
+ else:
500
+ import xatlas
501
+
502
+ v_np = self.v.detach().cpu().numpy()
503
+ f_np = self.f.detach().int().cpu().numpy()
504
+ atlas = xatlas.Atlas()
505
+ atlas.add_mesh(v_np, f_np)
506
+ chart_options = xatlas.ChartOptions()
507
+ # chart_options.max_iterations = 4
508
+ atlas.generate(chart_options=chart_options)
509
+ vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
510
+
511
+ # save to cache
512
+ if cache_path is not None:
513
+ np.savez(cache_path, vt=vt_np, ft=ft_np, vmapping=vmapping)
514
+
515
+ vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device)
516
+ ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device)
517
+ self.vt = vt
518
+ self.ft = ft
519
+
520
+ if vmap:
521
+ vmapping = torch.from_numpy(vmapping.astype(np.int64)).long().to(self.device)
522
+ self.align_v_to_vt(vmapping)
523
+
524
+ def align_v_to_vt(self, vmapping=None):
525
+ """ remap v/f and vn/fn to vt/ft.
526
+
527
+ Args:
528
+ vmapping (np.ndarray, optional): the mapping relationship from f to ft. Defaults to None.
529
+ """
530
+ if vmapping is None:
531
+ ft = self.ft.view(-1).long()
532
+ f = self.f.view(-1).long()
533
+ vmapping = torch.zeros(self.vt.shape[0], dtype=torch.long, device=self.device)
534
+ vmapping[ft] = f # scatter, randomly choose one if index is not unique
535
+
536
+ self.v = self.v[vmapping]
537
+ self.f = self.ft
538
+
539
+ if self.vn is not None:
540
+ self.vn = self.vn[vmapping]
541
+ self.fn = self.ft
542
+
543
+ def to(self, device):
544
+ """move all tensor attributes to device.
545
+
546
+ Args:
547
+ device (torch.device): target device.
548
+
549
+ Returns:
550
+ Mesh: self.
551
+ """
552
+ self.device = device
553
+ for name in ["v", "f", "vn", "fn", "vt", "ft", "albedo", "vc", "metallicRoughness"]:
554
+ tensor = getattr(self, name)
555
+ if tensor is not None:
556
+ setattr(self, name, tensor.to(device))
557
+ return self
558
+
559
+ def write(self, path):
560
+ """write the mesh to a path.
561
+
562
+ Args:
563
+ path (str): path to write, supports ply, obj and glb.
564
+ """
565
+ if path.endswith(".ply"):
566
+ self.write_ply(path)
567
+ elif path.endswith(".obj"):
568
+ self.write_obj(path)
569
+ elif path.endswith(".glb") or path.endswith(".gltf"):
570
+ self.write_glb(path)
571
+ else:
572
+ raise NotImplementedError(f"format {path} not supported!")
573
+
574
+ def write_ply(self, path):
575
+ """write the mesh in ply format. Only for geometry!
576
+
577
+ Args:
578
+ path (str): path to write.
579
+ """
580
+
581
+ if self.albedo is not None:
582
+ print(f'[WARN] ply format does not support exporting texture, will ignore!')
583
+
584
+ v_np = self.v.detach().cpu().numpy()
585
+ f_np = self.f.detach().cpu().numpy()
586
+
587
+ _mesh = trimesh.Trimesh(vertices=v_np, faces=f_np)
588
+ _mesh.export(path)
589
+
590
+
591
+ def write_glb(self, path):
592
+ """write the mesh in glb/gltf format.
593
+ This will create a scene with a single mesh.
594
+
595
+ Args:
596
+ path (str): path to write.
597
+ """
598
+
599
+ # assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0]
600
+ if self.vt is not None and self.v.shape[0] != self.vt.shape[0]:
601
+ self.align_v_to_vt()
602
+
603
+ import pygltflib
604
+
605
+ f_np = self.f.detach().cpu().numpy().astype(np.uint32)
606
+ f_np_blob = f_np.flatten().tobytes()
607
+
608
+ v_np = self.v.detach().cpu().numpy().astype(np.float32)
609
+ v_np_blob = v_np.tobytes()
610
+
611
+ blob = f_np_blob + v_np_blob
612
+ byteOffset = len(blob)
613
+
614
+ # base mesh
615
+ gltf = pygltflib.GLTF2(
616
+ scene=0,
617
+ scenes=[pygltflib.Scene(nodes=[0])],
618
+ nodes=[pygltflib.Node(mesh=0)],
619
+ meshes=[pygltflib.Mesh(primitives=[pygltflib.Primitive(
620
+ # indices to accessors (0 is triangles)
621
+ attributes=pygltflib.Attributes(
622
+ POSITION=1,
623
+ ),
624
+ indices=0,
625
+ )])],
626
+ buffers=[
627
+ pygltflib.Buffer(byteLength=len(f_np_blob) + len(v_np_blob))
628
+ ],
629
+ # buffer view (based on dtype)
630
+ bufferViews=[
631
+ # triangles; as flatten (element) array
632
+ pygltflib.BufferView(
633
+ buffer=0,
634
+ byteLength=len(f_np_blob),
635
+ target=pygltflib.ELEMENT_ARRAY_BUFFER, # GL_ELEMENT_ARRAY_BUFFER (34963)
636
+ ),
637
+ # positions; as vec3 array
638
+ pygltflib.BufferView(
639
+ buffer=0,
640
+ byteOffset=len(f_np_blob),
641
+ byteLength=len(v_np_blob),
642
+ byteStride=12, # vec3
643
+ target=pygltflib.ARRAY_BUFFER, # GL_ARRAY_BUFFER (34962)
644
+ ),
645
+ ],
646
+ accessors=[
647
+ # 0 = triangles
648
+ pygltflib.Accessor(
649
+ bufferView=0,
650
+ componentType=pygltflib.UNSIGNED_INT, # GL_UNSIGNED_INT (5125)
651
+ count=f_np.size,
652
+ type=pygltflib.SCALAR,
653
+ max=[int(f_np.max())],
654
+ min=[int(f_np.min())],
655
+ ),
656
+ # 1 = positions
657
+ pygltflib.Accessor(
658
+ bufferView=1,
659
+ componentType=pygltflib.FLOAT, # GL_FLOAT (5126)
660
+ count=len(v_np),
661
+ type=pygltflib.VEC3,
662
+ max=v_np.max(axis=0).tolist(),
663
+ min=v_np.min(axis=0).tolist(),
664
+ ),
665
+ ],
666
+ )
667
+
668
+ # append texture info
669
+ if self.vt is not None:
670
+
671
+ vt_np = self.vt.detach().cpu().numpy().astype(np.float32)
672
+ vt_np_blob = vt_np.tobytes()
673
+
674
+ albedo = self.albedo.detach().cpu().numpy()
675
+ albedo = (albedo * 255).astype(np.uint8)
676
+ albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)
677
+ albedo_blob = cv2.imencode('.png', albedo)[1].tobytes()
678
+
679
+ # update primitive
680
+ gltf.meshes[0].primitives[0].attributes.TEXCOORD_0 = 2
681
+ gltf.meshes[0].primitives[0].material = 0
682
+
683
+ # update materials
684
+ gltf.materials.append(pygltflib.Material(
685
+ pbrMetallicRoughness=pygltflib.PbrMetallicRoughness(
686
+ baseColorTexture=pygltflib.TextureInfo(index=0, texCoord=0),
687
+ metallicFactor=0.0,
688
+ roughnessFactor=1.0,
689
+ ),
690
+ alphaMode=pygltflib.OPAQUE,
691
+ alphaCutoff=None,
692
+ doubleSided=True,
693
+ ))
694
+
695
+ gltf.textures.append(pygltflib.Texture(sampler=0, source=0))
696
+ gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT))
697
+ gltf.images.append(pygltflib.Image(bufferView=3, mimeType="image/png"))
698
+
699
+ # update buffers
700
+ gltf.bufferViews.append(
701
+ # index = 2, texcoords; as vec2 array
702
+ pygltflib.BufferView(
703
+ buffer=0,
704
+ byteOffset=byteOffset,
705
+ byteLength=len(vt_np_blob),
706
+ byteStride=8, # vec2
707
+ target=pygltflib.ARRAY_BUFFER,
708
+ )
709
+ )
710
+
711
+ gltf.accessors.append(
712
+ # 2 = texcoords
713
+ pygltflib.Accessor(
714
+ bufferView=2,
715
+ componentType=pygltflib.FLOAT,
716
+ count=len(vt_np),
717
+ type=pygltflib.VEC2,
718
+ max=vt_np.max(axis=0).tolist(),
719
+ min=vt_np.min(axis=0).tolist(),
720
+ )
721
+ )
722
+
723
+ blob += vt_np_blob
724
+ byteOffset += len(vt_np_blob)
725
+
726
+ gltf.bufferViews.append(
727
+ # index = 3, albedo texture; as none target
728
+ pygltflib.BufferView(
729
+ buffer=0,
730
+ byteOffset=byteOffset,
731
+ byteLength=len(albedo_blob),
732
+ )
733
+ )
734
+
735
+ blob += albedo_blob
736
+ byteOffset += len(albedo_blob)
737
+
738
+ gltf.buffers[0].byteLength = byteOffset
739
+
740
+ # append metllic roughness
741
+ if self.metallicRoughness is not None:
742
+ metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
743
+ metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
744
+ metallicRoughness = cv2.cvtColor(metallicRoughness, cv2.COLOR_RGB2BGR)
745
+ metallicRoughness_blob = cv2.imencode('.png', metallicRoughness)[1].tobytes()
746
+
747
+ # update texture definition
748
+ gltf.materials[0].pbrMetallicRoughness.metallicFactor = 1.0
749
+ gltf.materials[0].pbrMetallicRoughness.roughnessFactor = 1.0
750
+ gltf.materials[0].pbrMetallicRoughness.metallicRoughnessTexture = pygltflib.TextureInfo(index=1, texCoord=0)
751
+
752
+ gltf.textures.append(pygltflib.Texture(sampler=1, source=1))
753
+ gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT))
754
+ gltf.images.append(pygltflib.Image(bufferView=4, mimeType="image/png"))
755
+
756
+ # update buffers
757
+ gltf.bufferViews.append(
758
+ # index = 4, metallicRoughness texture; as none target
759
+ pygltflib.BufferView(
760
+ buffer=0,
761
+ byteOffset=byteOffset,
762
+ byteLength=len(metallicRoughness_blob),
763
+ )
764
+ )
765
+
766
+ blob += metallicRoughness_blob
767
+ byteOffset += len(metallicRoughness_blob)
768
+
769
+ gltf.buffers[0].byteLength = byteOffset
770
+
771
+
772
+ # set actual data
773
+ gltf.set_binary_blob(blob)
774
+
775
+ # glb = b"".join(gltf.save_to_bytes())
776
+ gltf.save(path)
777
+
778
+
779
+ def write_obj(self, path):
780
+ """write the mesh in obj format. Will also write the texture and mtl files.
781
+
782
+ Args:
783
+ path (str): path to write.
784
+ """
785
+
786
+ mtl_path = path.replace(".obj", ".mtl")
787
+ albedo_path = path.replace(".obj", "_albedo.png")
788
+ metallic_path = path.replace(".obj", "_metallic.png")
789
+ roughness_path = path.replace(".obj", "_roughness.png")
790
+
791
+ v_np = self.v.detach().cpu().numpy()
792
+ vt_np = self.vt.detach().cpu().numpy() if self.vt is not None else None
793
+ vn_np = self.vn.detach().cpu().numpy() if self.vn is not None else None
794
+ f_np = self.f.detach().cpu().numpy()
795
+ ft_np = self.ft.detach().cpu().numpy() if self.ft is not None else None
796
+ fn_np = self.fn.detach().cpu().numpy() if self.fn is not None else None
797
+
798
+ with open(path, "w") as fp:
799
+ fp.write(f"mtllib {os.path.basename(mtl_path)} \n")
800
+
801
+ for v in v_np:
802
+ fp.write(f"v {v[0]} {v[1]} {v[2]} \n")
803
+
804
+ if vt_np is not None:
805
+ for v in vt_np:
806
+ fp.write(f"vt {v[0]} {1 - v[1]} \n")
807
+
808
+ if vn_np is not None:
809
+ for v in vn_np:
810
+ fp.write(f"vn {v[0]} {v[1]} {v[2]} \n")
811
+
812
+ fp.write(f"usemtl defaultMat \n")
813
+ for i in range(len(f_np)):
814
+ fp.write(
815
+ f'f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1 if ft_np is not None else ""}/{fn_np[i, 0] + 1 if fn_np is not None else ""} \
816
+ {f_np[i, 1] + 1}/{ft_np[i, 1] + 1 if ft_np is not None else ""}/{fn_np[i, 1] + 1 if fn_np is not None else ""} \
817
+ {f_np[i, 2] + 1}/{ft_np[i, 2] + 1 if ft_np is not None else ""}/{fn_np[i, 2] + 1 if fn_np is not None else ""} \n'
818
+ )
819
+
820
+ with open(mtl_path, "w") as fp:
821
+ fp.write(f"newmtl defaultMat \n")
822
+ fp.write(f"Ka 1 1 1 \n")
823
+ fp.write(f"Kd 1 1 1 \n")
824
+ fp.write(f"Ks 0 0 0 \n")
825
+ fp.write(f"Tr 1 \n")
826
+ fp.write(f"illum 1 \n")
827
+ fp.write(f"Ns 0 \n")
828
+ if self.albedo is not None:
829
+ fp.write(f"map_Kd {os.path.basename(albedo_path)} \n")
830
+ if self.metallicRoughness is not None:
831
+ # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering
832
+ fp.write(f"map_Pm {os.path.basename(metallic_path)} \n")
833
+ fp.write(f"map_Pr {os.path.basename(roughness_path)} \n")
834
+
835
+ if self.albedo is not None:
836
+ albedo = self.albedo.detach().cpu().numpy()
837
+ albedo = (albedo * 255).astype(np.uint8)
838
+ cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR))
839
+
840
+ if self.metallicRoughness is not None:
841
+ metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
842
+ metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
843
+ cv2.imwrite(metallic_path, metallicRoughness[..., 2])
844
+ cv2.imwrite(roughness_path, metallicRoughness[..., 1])
845
+