ginipick commited on
Commit
c9a47aa
·
verified ·
1 Parent(s): caeb28c

Update scripts/utils.py

Browse files
Files changed (1) hide show
  1. scripts/utils.py +67 -27
scripts/utils.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import torch
2
  import numpy as np
3
  from PIL import Image
@@ -7,22 +9,42 @@ from pymeshlab import PercentageValue
7
  from pytorch3d.renderer import TexturesVertex
8
  from pytorch3d.structures import Meshes
9
  from rembg import new_session, remove
10
- import torch
11
  import torch.nn.functional as F
12
  from typing import List, Tuple
13
- from PIL import Image
14
  import trimesh
15
 
16
- providers = [
17
- ('CUDAExecutionProvider', {
18
- 'device_id': 0,
19
- 'arena_extend_strategy': 'kSameAsRequested',
20
- 'gpu_mem_limit': 8 * 1024 * 1024 * 1024,
21
- 'cudnn_conv_algo_search': 'HEURISTIC',
22
- })
23
- ]
24
 
25
- session = new_session(providers=providers)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  NEG_PROMPT="sketch, sculpture, hand drawing, outline, single color, NSFW, lowres, bad anatomy,bad hands, text, error, missing fingers, yellow sleeves, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry,(worst quality:1.4),(low quality:1.4)"
28
 
@@ -62,7 +84,6 @@ def meshlab_mesh_to_py3dmesh(mesh: pymeshlab.Mesh) -> Meshes:
62
  textures = TexturesVertex(verts_features=[colors])
63
  return Meshes(verts=[verts], faces=[faces], textures=textures)
64
 
65
-
66
  def py3dmesh_to_meshlab_mesh(meshes: Meshes) -> pymeshlab.Mesh:
67
  colors_in = F.pad(meshes.textures.verts_features_packed().cpu().float(), [0,1], value=1).numpy().astype(np.float64)
68
  m1 = pymeshlab.Mesh(
@@ -72,7 +93,6 @@ def py3dmesh_to_meshlab_mesh(meshes: Meshes) -> pymeshlab.Mesh:
72
  v_color_matrix=colors_in)
73
  return m1
74
 
75
-
76
  def to_pyml_mesh(vertices,faces):
77
  m1 = pymeshlab.Mesh(
78
  vertex_matrix=vertices.cpu().float().numpy().astype(np.float64),
@@ -80,7 +100,6 @@ def to_pyml_mesh(vertices,faces):
80
  )
81
  return m1
82
 
83
-
84
  def to_py3d_mesh(vertices, faces, normals=None):
85
  from pytorch3d.structures import Meshes
86
  from pytorch3d.renderer.mesh.textures import TexturesVertex
@@ -91,7 +110,6 @@ def to_py3d_mesh(vertices, faces, normals=None):
91
  mesh.textures = TexturesVertex(verts_features=[normals / 2 + 0.5])
92
  return mesh
93
 
94
-
95
  def from_py3d_mesh(mesh):
96
  return mesh.verts_list()[0], mesh.faces_list()[0], mesh.textures.verts_features_packed()
97
 
@@ -126,7 +144,6 @@ def rotate_normals(normal_pils, return_types='np', rotate_direction=1) -> np.nda
126
  raise ValueError(f"return_types should be 'np' or 'pil', but got {return_types}")
127
  return ret
128
 
129
-
130
  def rotate_normalmap_by_angle_torch(normal_map, angle):
131
  """
132
  rotate along y-axis
@@ -140,7 +157,9 @@ def rotate_normalmap_by_angle_torch(normal_map, angle):
140
  return torch.matmul(normal_map.view(-1, 3), R.T).view(normal_map.shape)
141
 
142
  def do_rotate(rgba_normal, angle):
143
- rgba_normal = torch.from_numpy(rgba_normal).float().cuda() / 255
 
 
144
  rotated_normal_tensor = rotate_normalmap_by_angle_torch(rgba_normal[..., :3] * 2 - 1, angle)
145
  rotated_normal_tensor = (rotated_normal_tensor + 1) / 2
146
  rotated_normal_tensor = rotated_normal_tensor * rgba_normal[:, :, [3]] # make bg black
@@ -195,7 +214,6 @@ def change_bkgd_to_normal(normal_pils) -> List[Image.Image]:
195
  ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8)))
196
  return ret
197
 
198
-
199
  def fix_vert_color_glb(mesh_path):
200
  from pygltflib import GLTF2, Material, PbrMetallicRoughness
201
  obj1 = GLTF2().load(mesh_path)
@@ -211,12 +229,10 @@ def fix_vert_color_glb(mesh_path):
211
  ))
212
  obj1.save(mesh_path)
213
 
214
-
215
  def srgb_to_linear(c_srgb):
216
  c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4)
217
  return c_linear.clip(0, 1.)
218
 
219
-
220
  def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True):
221
  # convert from pytorch3d meshes to trimesh mesh
222
  vertices = meshes.verts_packed().cpu().float().numpy()
@@ -239,7 +255,6 @@ def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to
239
  fix_vert_color_glb(save_glb_path)
240
  print(f"saving to {save_glb_path}")
241
 
242
-
243
  def save_glb_and_video(save_mesh_prefix: str, meshes: Meshes, with_timestamp=True, dist=3.5, azim_offset=180, resolution=512, fov_in_degrees=1 / 1.15, cam_type="ortho", view_padding=60, export_video=True) -> Tuple[str, str]:
244
  import time
245
  if '.' in save_mesh_prefix:
@@ -251,7 +266,6 @@ def save_glb_and_video(save_mesh_prefix: str, meshes: Meshes, with_timestamp=Tru
251
  save_py3dmesh_with_trimesh_fast(meshes, ret_mesh)
252
  return ret_mesh, None
253
 
254
-
255
  def simple_clean_mesh(pyml_mesh: ml.Mesh, apply_smooth=True, stepsmoothnum=1, apply_sub_divide=False, sub_divide_threshold=0.25):
256
  ms = ml.MeshSet()
257
  ms.add_mesh(pyml_mesh, "cube_mesh")
@@ -264,7 +278,6 @@ def simple_clean_mesh(pyml_mesh: ml.Mesh, apply_smooth=True, stepsmoothnum=1, ap
264
  ms.apply_filter("meshing_surface_subdivision_loop", iterations=2, threshold=PercentageValue(sub_divide_threshold))
265
  return meshlab_mesh_to_py3dmesh(ms.current_mesh())
266
 
267
-
268
  def expand2square(pil_img, background_color):
269
  width, height = pil_img.size
270
  if width == height:
@@ -278,13 +291,37 @@ def expand2square(pil_img, background_color):
278
  result.paste(pil_img, ((height - width) // 2, 0))
279
  return result
280
 
281
-
282
- def simple_preprocess(input_image, rembg_session=session, background_color=255):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  RES = 2048
284
  input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)
285
  if input_image.mode != 'RGBA':
286
  image_rem = input_image.convert('RGBA')
287
- input_image = remove(image_rem, alpha_matting=False, session=rembg_session)
 
 
 
 
 
 
 
288
 
289
  arr = np.asarray(input_image)
290
  alpha = np.asarray(input_image)[:, :, -1]
@@ -299,7 +336,10 @@ def simple_preprocess(input_image, rembg_session=session, background_color=255):
299
  input_image = expand2square(input_image, (background_color, background_color, background_color, 0))
300
  return input_image
301
 
302
- def init_target(img_pils, new_bkgd=(0., 0., 0.), device="cuda"):
 
 
 
303
  # Convert the background color to a PyTorch tensor
304
  new_bkgd = torch.tensor(new_bkgd, dtype=torch.float32).view(1, 1, 3).to(device)
305
 
 
1
+ import os
2
+ import spaces
3
  import torch
4
  import numpy as np
5
  from PIL import Image
 
9
  from pytorch3d.renderer import TexturesVertex
10
  from pytorch3d.structures import Meshes
11
  from rembg import new_session, remove
 
12
  import torch.nn.functional as F
13
  from typing import List, Tuple
 
14
  import trimesh
15
 
16
+ # ZeroGPU 환경 감지
17
+ IS_ZEROGPU = os.environ.get('SPACE_ID') is not None or os.environ.get('ZEROGPU') is not None
 
 
 
 
 
 
18
 
19
+ # 전역 변수로 session 선언 (초기에는 None)
20
+ _session = None
21
+ _gpu_session = None
22
+
23
+ def get_providers():
24
+ """환경에 따른 적절한 providers 반환"""
25
+ if IS_ZEROGPU:
26
+ # ZeroGPU 환경에서는 초기에 CPU만 사용
27
+ return ['CPUExecutionProvider']
28
+ else:
29
+ # 일반 환경에서는 CUDA 우선 사용
30
+ return [
31
+ ('CUDAExecutionProvider', {
32
+ 'device_id': 0,
33
+ 'arena_extend_strategy': 'kSameAsRequested',
34
+ 'gpu_mem_limit': 8 * 1024 * 1024 * 1024,
35
+ 'cudnn_conv_algo_search': 'HEURISTIC',
36
+ })
37
+ ]
38
+
39
+ def get_session():
40
+ """세션을 lazy loading으로 생성"""
41
+ global _session
42
+ if _session is None:
43
+ _session = new_session(providers=get_providers())
44
+ return _session
45
+
46
+ # 기존 코드와의 호환성을 위한 session 변수
47
+ session = None # 초기에는 None, 필요시 get_session() 사용
48
 
49
  NEG_PROMPT="sketch, sculpture, hand drawing, outline, single color, NSFW, lowres, bad anatomy,bad hands, text, error, missing fingers, yellow sleeves, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry,(worst quality:1.4),(low quality:1.4)"
50
 
 
84
  textures = TexturesVertex(verts_features=[colors])
85
  return Meshes(verts=[verts], faces=[faces], textures=textures)
86
 
 
87
  def py3dmesh_to_meshlab_mesh(meshes: Meshes) -> pymeshlab.Mesh:
88
  colors_in = F.pad(meshes.textures.verts_features_packed().cpu().float(), [0,1], value=1).numpy().astype(np.float64)
89
  m1 = pymeshlab.Mesh(
 
93
  v_color_matrix=colors_in)
94
  return m1
95
 
 
96
  def to_pyml_mesh(vertices,faces):
97
  m1 = pymeshlab.Mesh(
98
  vertex_matrix=vertices.cpu().float().numpy().astype(np.float64),
 
100
  )
101
  return m1
102
 
 
103
  def to_py3d_mesh(vertices, faces, normals=None):
104
  from pytorch3d.structures import Meshes
105
  from pytorch3d.renderer.mesh.textures import TexturesVertex
 
110
  mesh.textures = TexturesVertex(verts_features=[normals / 2 + 0.5])
111
  return mesh
112
 
 
113
  def from_py3d_mesh(mesh):
114
  return mesh.verts_list()[0], mesh.faces_list()[0], mesh.textures.verts_features_packed()
115
 
 
144
  raise ValueError(f"return_types should be 'np' or 'pil', but got {return_types}")
145
  return ret
146
 
 
147
  def rotate_normalmap_by_angle_torch(normal_map, angle):
148
  """
149
  rotate along y-axis
 
157
  return torch.matmul(normal_map.view(-1, 3), R.T).view(normal_map.shape)
158
 
159
  def do_rotate(rgba_normal, angle):
160
+ # GPU 사용 가능 여부 확인
161
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
162
+ rgba_normal = torch.from_numpy(rgba_normal).float().to(device) / 255
163
  rotated_normal_tensor = rotate_normalmap_by_angle_torch(rgba_normal[..., :3] * 2 - 1, angle)
164
  rotated_normal_tensor = (rotated_normal_tensor + 1) / 2
165
  rotated_normal_tensor = rotated_normal_tensor * rgba_normal[:, :, [3]] # make bg black
 
214
  ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8)))
215
  return ret
216
 
 
217
  def fix_vert_color_glb(mesh_path):
218
  from pygltflib import GLTF2, Material, PbrMetallicRoughness
219
  obj1 = GLTF2().load(mesh_path)
 
229
  ))
230
  obj1.save(mesh_path)
231
 
 
232
  def srgb_to_linear(c_srgb):
233
  c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4)
234
  return c_linear.clip(0, 1.)
235
 
 
236
  def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True):
237
  # convert from pytorch3d meshes to trimesh mesh
238
  vertices = meshes.verts_packed().cpu().float().numpy()
 
255
  fix_vert_color_glb(save_glb_path)
256
  print(f"saving to {save_glb_path}")
257
 
 
258
  def save_glb_and_video(save_mesh_prefix: str, meshes: Meshes, with_timestamp=True, dist=3.5, azim_offset=180, resolution=512, fov_in_degrees=1 / 1.15, cam_type="ortho", view_padding=60, export_video=True) -> Tuple[str, str]:
259
  import time
260
  if '.' in save_mesh_prefix:
 
266
  save_py3dmesh_with_trimesh_fast(meshes, ret_mesh)
267
  return ret_mesh, None
268
 
 
269
  def simple_clean_mesh(pyml_mesh: ml.Mesh, apply_smooth=True, stepsmoothnum=1, apply_sub_divide=False, sub_divide_threshold=0.25):
270
  ms = ml.MeshSet()
271
  ms.add_mesh(pyml_mesh, "cube_mesh")
 
278
  ms.apply_filter("meshing_surface_subdivision_loop", iterations=2, threshold=PercentageValue(sub_divide_threshold))
279
  return meshlab_mesh_to_py3dmesh(ms.current_mesh())
280
 
 
281
  def expand2square(pil_img, background_color):
282
  width, height = pil_img.size
283
  if width == height:
 
291
  result.paste(pil_img, ((height - width) // 2, 0))
292
  return result
293
 
294
+ # ZeroGPU용 배경 제거 함수
295
+ @spaces.GPU(duration=30)
296
+ def remove_background_gpu(input_image, alpha_matting=False):
297
+ """GPU에서 배경 제거 실행"""
298
+ global _gpu_session
299
+ if _gpu_session is None:
300
+ # GPU가 할당되면 CUDA 프로바이더로 새 세션 생성
301
+ gpu_providers = [
302
+ ('CUDAExecutionProvider', {
303
+ 'device_id': 0,
304
+ 'arena_extend_strategy': 'kSameAsRequested',
305
+ 'gpu_mem_limit': 8 * 1024 * 1024 * 1024,
306
+ 'cudnn_conv_algo_search': 'HEURISTIC',
307
+ })
308
+ ]
309
+ _gpu_session = new_session(providers=gpu_providers)
310
+ return remove(input_image, alpha_matting=alpha_matting, session=_gpu_session)
311
+
312
+ def simple_preprocess(input_image, rembg_session=None, background_color=255):
313
  RES = 2048
314
  input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)
315
  if input_image.mode != 'RGBA':
316
  image_rem = input_image.convert('RGBA')
317
+ # ZeroGPU 환경에서는 GPU 함수 사용
318
+ if IS_ZEROGPU:
319
+ input_image = remove_background_gpu(image_rem, alpha_matting=False)
320
+ else:
321
+ # 일반 환경에서는 세션 사용
322
+ if rembg_session is None:
323
+ rembg_session = get_session()
324
+ input_image = remove(image_rem, alpha_matting=False, session=rembg_session)
325
 
326
  arr = np.asarray(input_image)
327
  alpha = np.asarray(input_image)[:, :, -1]
 
336
  input_image = expand2square(input_image, (background_color, background_color, background_color, 0))
337
  return input_image
338
 
339
+ def init_target(img_pils, new_bkgd=(0., 0., 0.), device=None):
340
+ if device is None:
341
+ device = "cuda" if torch.cuda.is_available() else "cpu"
342
+
343
  # Convert the background color to a PyTorch tensor
344
  new_bkgd = torch.tensor(new_bkgd, dtype=torch.float32).view(1, 1, 3).to(device)
345