Mariam-Elz commited on
Commit
6855039
·
verified ·
1 Parent(s): 7b5c5ef

Upload util/flexicubes_geometry.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. util/flexicubes_geometry.py +116 -116
util/flexicubes_geometry.py CHANGED
@@ -1,116 +1,116 @@
1
- # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
-
9
- import torch
10
- from util.flexicubes import FlexiCubes # replace later
11
- # from dmtet import sdf_reg_loss_batch
12
- import torch.nn.functional as F
13
-
14
- def get_center_boundary_index(grid_res, device):
15
- v = torch.zeros((grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device)
16
- v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True
17
- center_indices = torch.nonzero(v.reshape(-1))
18
-
19
- v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False
20
- v[:2, ...] = True
21
- v[-2:, ...] = True
22
- v[:, :2, ...] = True
23
- v[:, -2:, ...] = True
24
- v[:, :, :2] = True
25
- v[:, :, -2:] = True
26
- boundary_indices = torch.nonzero(v.reshape(-1))
27
- return center_indices, boundary_indices
28
-
29
- ###############################################################################
30
- # Geometry interface
31
- ###############################################################################
32
- class FlexiCubesGeometry(object):
33
- def __init__(
34
- self, grid_res=64, scale=2.0, device='cuda', renderer=None,
35
- render_type='neural_render', args=None):
36
- super(FlexiCubesGeometry, self).__init__()
37
- self.grid_res = grid_res
38
- self.device = device
39
- self.args = args
40
- self.fc = FlexiCubes(device, weight_scale=0.5)
41
- self.verts, self.indices = self.fc.construct_voxel_grid(grid_res)
42
- if isinstance(scale, list):
43
- self.verts[:, 0] = self.verts[:, 0] * scale[0]
44
- self.verts[:, 1] = self.verts[:, 1] * scale[1]
45
- self.verts[:, 2] = self.verts[:, 2] * scale[1]
46
- else:
47
- self.verts = self.verts * scale
48
-
49
- all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2)
50
- self.all_edges = torch.unique(all_edges, dim=0)
51
-
52
- # Parameters used for fix boundary sdf
53
- self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device)
54
- self.renderer = renderer
55
- self.render_type = render_type
56
-
57
- def getAABB(self):
58
- return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
59
-
60
- def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False):
61
- if indices is None:
62
- indices = self.indices
63
-
64
- verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, self.grid_res,
65
- beta_fx12=weight_n[:, :12], alpha_fx8=weight_n[:, 12:20],
66
- gamma_f=weight_n[:, 20], training=is_training
67
- )
68
- return verts, faces, v_reg_loss
69
-
70
-
71
- def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False):
72
- return_value = dict()
73
- if self.render_type == 'neural_render':
74
- tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth = self.renderer.render_mesh(
75
- mesh_v_nx3.unsqueeze(dim=0),
76
- mesh_f_fx3.int(),
77
- camera_mv_bx4x4,
78
- mesh_v_nx3.unsqueeze(dim=0),
79
- resolution=resolution,
80
- device=self.device,
81
- hierarchical_mask=hierarchical_mask
82
- )
83
-
84
- return_value['tex_pos'] = tex_pos
85
- return_value['mask'] = mask
86
- return_value['hard_mask'] = hard_mask
87
- return_value['rast'] = rast
88
- return_value['v_pos_clip'] = v_pos_clip
89
- return_value['mask_pyramid'] = mask_pyramid
90
- return_value['depth'] = depth
91
- else:
92
- raise NotImplementedError
93
-
94
- return return_value
95
-
96
- def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256):
97
- # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1
98
- v_list = []
99
- f_list = []
100
- n_batch = v_deformed_bxnx3.shape[0]
101
- all_render_output = []
102
- for i_batch in range(n_batch):
103
- verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch])
104
- v_list.append(verts_nx3)
105
- f_list.append(faces_fx3)
106
- render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution)
107
- all_render_output.append(render_output)
108
-
109
- # Concatenate all render output
110
- return_keys = all_render_output[0].keys()
111
- return_value = dict()
112
- for k in return_keys:
113
- value = [v[k] for v in all_render_output]
114
- return_value[k] = value
115
- # We can do concatenation outside of the render
116
- return return_value
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+ from util.flexicubes import FlexiCubes # replace later
11
+ # from dmtet import sdf_reg_loss_batch
12
+ import torch.nn.functional as F
13
+
14
+ def get_center_boundary_index(grid_res, device):
15
+ v = torch.zeros((grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device)
16
+ v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True
17
+ center_indices = torch.nonzero(v.reshape(-1))
18
+
19
+ v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False
20
+ v[:2, ...] = True
21
+ v[-2:, ...] = True
22
+ v[:, :2, ...] = True
23
+ v[:, -2:, ...] = True
24
+ v[:, :, :2] = True
25
+ v[:, :, -2:] = True
26
+ boundary_indices = torch.nonzero(v.reshape(-1))
27
+ return center_indices, boundary_indices
28
+
29
+ ###############################################################################
30
+ # Geometry interface
31
+ ###############################################################################
32
+ class FlexiCubesGeometry(object):
33
+ def __init__(
34
+ self, grid_res=64, scale=2.0, device='cuda', renderer=None,
35
+ render_type='neural_render', args=None):
36
+ super(FlexiCubesGeometry, self).__init__()
37
+ self.grid_res = grid_res
38
+ self.device = device
39
+ self.args = args
40
+ self.fc = FlexiCubes(device, weight_scale=0.5)
41
+ self.verts, self.indices = self.fc.construct_voxel_grid(grid_res)
42
+ if isinstance(scale, list):
43
+ self.verts[:, 0] = self.verts[:, 0] * scale[0]
44
+ self.verts[:, 1] = self.verts[:, 1] * scale[1]
45
+ self.verts[:, 2] = self.verts[:, 2] * scale[1]
46
+ else:
47
+ self.verts = self.verts * scale
48
+
49
+ all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2)
50
+ self.all_edges = torch.unique(all_edges, dim=0)
51
+
52
+ # Parameters used for fix boundary sdf
53
+ self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device)
54
+ self.renderer = renderer
55
+ self.render_type = render_type
56
+
57
+ def getAABB(self):
58
+ return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
59
+
60
+ def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False):
61
+ if indices is None:
62
+ indices = self.indices
63
+
64
+ verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, self.grid_res,
65
+ beta_fx12=weight_n[:, :12], alpha_fx8=weight_n[:, 12:20],
66
+ gamma_f=weight_n[:, 20], training=is_training
67
+ )
68
+ return verts, faces, v_reg_loss
69
+
70
+
71
+ def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False):
72
+ return_value = dict()
73
+ if self.render_type == 'neural_render':
74
+ tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth = self.renderer.render_mesh(
75
+ mesh_v_nx3.unsqueeze(dim=0),
76
+ mesh_f_fx3.int(),
77
+ camera_mv_bx4x4,
78
+ mesh_v_nx3.unsqueeze(dim=0),
79
+ resolution=resolution,
80
+ device=self.device,
81
+ hierarchical_mask=hierarchical_mask
82
+ )
83
+
84
+ return_value['tex_pos'] = tex_pos
85
+ return_value['mask'] = mask
86
+ return_value['hard_mask'] = hard_mask
87
+ return_value['rast'] = rast
88
+ return_value['v_pos_clip'] = v_pos_clip
89
+ return_value['mask_pyramid'] = mask_pyramid
90
+ return_value['depth'] = depth
91
+ else:
92
+ raise NotImplementedError
93
+
94
+ return return_value
95
+
96
+ def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256):
97
+ # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1
98
+ v_list = []
99
+ f_list = []
100
+ n_batch = v_deformed_bxnx3.shape[0]
101
+ all_render_output = []
102
+ for i_batch in range(n_batch):
103
+ verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch])
104
+ v_list.append(verts_nx3)
105
+ f_list.append(faces_fx3)
106
+ render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution)
107
+ all_render_output.append(render_output)
108
+
109
+ # Concatenate all render output
110
+ return_keys = all_render_output[0].keys()
111
+ return_value = dict()
112
+ for k in return_keys:
113
+ value = [v[k] for v in all_render_output]
114
+ return_value[k] = value
115
+ # We can do concatenation outside of the render
116
+ return return_value