|
import argparse |
|
import os |
|
import sys |
|
from glob import glob |
|
from typing import Any, Union |
|
|
|
import numpy as np |
|
import torch |
|
import trimesh |
|
from huggingface_hub import snapshot_download |
|
from PIL import Image |
|
|
|
sys.path.append(os.path.dirname(os.path.abspath(__file__))) |
|
|
|
from triposg.pipelines.pipeline_triposg import TripoSGPipeline |
|
from image_process import prepare_image |
|
from briarmbg import BriaRMBG |
|
|
|
import pymeshlab |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def run_triposg(
    pipe: Any,
    image_input: Union[str, Image.Image],
    rmbg_net: Any,
    seed: int,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.0,
    faces: int = -1,
) -> trimesh.Trimesh:
    """Run the TripoSG pipeline on one image and return the resulting mesh.

    Gradient tracking is disabled for the whole call via ``torch.no_grad``.

    Args:
        pipe: Callable pipeline object; it must expose ``.device`` and return
            an object with a ``.samples`` sequence when called.
        image_input: Image (or path to one) forwarded to ``prepare_image``.
        rmbg_net: Background-removal network forwarded to ``prepare_image``.
        seed: Seed for the ``torch.Generator`` handed to the pipeline.
        num_inference_steps: Forwarded unchanged to the pipeline.
        guidance_scale: Forwarded unchanged to the pipeline.
        faces: Target face count. When greater than zero the mesh is passed
            through ``simplify_mesh``; otherwise it is returned as generated.

    Returns:
        The generated ``trimesh.Trimesh`` (simplified when ``faces > 0``).
    """
    print("[DEBUG] Preparing image...")
    # bg_color is [1.0, 1.0, 1.0] (presumably white in normalized RGB -- confirm
    # against prepare_image).
    img_pil = prepare_image(image_input, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)

    print("[DEBUG] Running TripoSG pipeline...")
    outputs = pipe(
        image=img_pil,
        # The generator is created on the pipeline's device and seeded so runs
        # with the same seed use the same random stream.
        generator=torch.Generator(device=pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).samples[0]

    print("[DEBUG] TripoSG output keys:", type(outputs), outputs[0].shape, outputs[1].shape)

    # outputs[0] is used as the vertex array and outputs[1] as the face array.
    mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1]))
    print(f"[DEBUG] Mesh created: {mesh.vertices.shape[0]} verts / {mesh.faces.shape[0]} faces")

    if faces > 0:
        print(f"[DEBUG] Simplifying mesh to {faces} faces")
        # Previously this branch only logged; actually perform the decimation
        # so the requested face budget is honored.
        mesh = simplify_mesh(mesh, faces)

    return mesh
|
|
|
|
|
def mesh_to_pymesh(vertices, faces):
    """Return a ``pymeshlab.MeshSet`` holding one mesh built from the arrays.

    Args:
        vertices: Passed to ``pymeshlab.Mesh`` as ``vertex_matrix``.
        faces: Passed to ``pymeshlab.Mesh`` as ``face_matrix``.
    """
    single_mesh = pymeshlab.Mesh(face_matrix=faces, vertex_matrix=vertices)
    mesh_set = pymeshlab.MeshSet()
    mesh_set.add_mesh(single_mesh)
    return mesh_set
|
|
|
def pymesh_to_trimesh(mesh):
    """Build a ``trimesh.Trimesh`` from a mesh exposing matrix accessors.

    Args:
        mesh: Object providing ``vertex_matrix()`` and ``face_matrix()``.
    """
    return trimesh.Trimesh(
        vertices=mesh.vertex_matrix(),
        faces=mesh.face_matrix(),
    )
|
|
|
def simplify_mesh(mesh: trimesh.Trimesh, n_faces):
    """Decimate ``mesh`` toward ``n_faces`` faces when it has more than that.

    Args:
        mesh: Input mesh; returned untouched if it is already within budget.
        n_faces: Target face count handed to the quadric edge-collapse filter.

    Returns:
        A new ``trimesh.Trimesh`` built from the decimated mesh, or the
        original ``mesh`` object when no decimation is needed.
    """
    over_budget = mesh.faces.shape[0] > n_faces
    if not over_budget:
        return mesh

    mesh_set = mesh_to_pymesh(mesh.vertices, mesh.faces)
    # Merge close vertices first, then run the edge-collapse decimation.
    mesh_set.meshing_merge_close_vertices()
    mesh_set.meshing_decimation_quadric_edge_collapse(targetfacenum=n_faces)
    decimated = mesh_set.current_mesh()
    return pymesh_to_trimesh(decimated)
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, fetch weights, load both
    # models, generate the mesh, and export it.
    device = "cuda"
    dtype = torch.float16

    parser = argparse.ArgumentParser()
    parser.add_argument("--image-input", required=True, type=str)
    parser.add_argument("--output-path", default="./output.glb", type=str)
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--num-inference-steps", default=50, type=int)
    parser.add_argument("--guidance-scale", default=7.0, type=float)
    parser.add_argument("--faces", default=-1, type=int)
    args = parser.parse_args()

    # Download the TripoSG and RMBG weight snapshots into local directories.
    triposg_weights_dir = "pretrained_weights/TripoSG"
    rmbg_weights_dir = "pretrained_weights/RMBG-1.4"
    for repo_id, local_dir in (
        ("VAST-AI/TripoSG", triposg_weights_dir),
        ("briaai/RMBG-1.4", rmbg_weights_dir),
    ):
        snapshot_download(repo_id=repo_id, local_dir=local_dir)

    # Background-removal network, moved to the device and put in eval mode.
    rmbg_net = BriaRMBG.from_pretrained(rmbg_weights_dir)
    rmbg_net = rmbg_net.to(device)
    rmbg_net.eval()

    # TripoSG pipeline, moved to the device with the chosen dtype.
    pipe: TripoSGPipeline = TripoSGPipeline.from_pretrained(triposg_weights_dir)
    pipe = pipe.to(device, dtype)

    result_mesh = run_triposg(
        pipe,
        image_input=args.image_input,
        rmbg_net=rmbg_net,
        seed=args.seed,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        faces=args.faces,
    )
    result_mesh.export(args.output_path)
    print(f"Mesh saved to {args.output_path}")
|
|