Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import shutil | |
| from dataclasses import dataclass | |
| import tyro | |
| from embodied_gen.data.backproject_v2 import entrypoint as backproject_api | |
| from embodied_gen.data.differentiable_render import entrypoint as drender_api | |
| from embodied_gen.data.utils import as_list | |
| from embodied_gen.models.delight_model import DelightingModel | |
| from embodied_gen.models.sr_model import ImageRealESRGAN | |
| from embodied_gen.scripts.render_mv import ( | |
| build_texture_gen_pipe, | |
| ) | |
| from embodied_gen.scripts.render_mv import infer_pipe as render_mv_api | |
| from embodied_gen.utils.log import logger | |
| class TextureGenConfig: | |
| mesh_path: str | list[str] | |
| prompt: str | list[str] | |
| output_root: str | |
| controlnet_cond_scale: float = 0.7 | |
| guidance_scale: float = 9 | |
| strength: float = 0.9 | |
| num_inference_steps: int = 40 | |
| delight: bool = True | |
| seed: int = 0 | |
| base_ckpt_dir: str = "./weights" | |
| texture_size: int = 2048 | |
| ip_adapt_scale: float = 0.0 | |
| ip_img_path: str | list[str] | None = None | |
| def entrypoint() -> None: | |
| cfg = tyro.cli(TextureGenConfig) | |
| cfg.mesh_path = as_list(cfg.mesh_path) | |
| cfg.prompt = as_list(cfg.prompt) | |
| cfg.ip_img_path = as_list(cfg.ip_img_path) | |
| assert len(cfg.mesh_path) == len(cfg.prompt) | |
| # Pre-load models. | |
| if cfg.ip_adapt_scale > 0: | |
| PIPELINE = build_texture_gen_pipe( | |
| base_ckpt_dir="./weights", | |
| ip_adapt_scale=cfg.ip_adapt_scale, | |
| device="cuda", | |
| ) | |
| else: | |
| PIPELINE = build_texture_gen_pipe( | |
| base_ckpt_dir="./weights", | |
| ip_adapt_scale=0, | |
| device="cuda", | |
| ) | |
| DELIGHT = None | |
| if cfg.delight: | |
| DELIGHT = DelightingModel() | |
| IMAGESR_MODEL = ImageRealESRGAN(outscale=4) | |
| for idx in range(len(cfg.mesh_path)): | |
| mesh_path = cfg.mesh_path[idx] | |
| prompt = cfg.prompt[idx] | |
| uuid = os.path.splitext(os.path.basename(mesh_path))[0] | |
| output_root = os.path.join(cfg.output_root, uuid) | |
| drender_api( | |
| mesh_path=mesh_path, | |
| output_root=f"{output_root}/condition", | |
| uuid=uuid, | |
| ) | |
| render_mv_api( | |
| index_file=f"{output_root}/condition/index.json", | |
| controlnet_cond_scale=cfg.controlnet_cond_scale, | |
| guidance_scale=cfg.guidance_scale, | |
| strength=cfg.strength, | |
| num_inference_steps=cfg.num_inference_steps, | |
| ip_adapt_scale=cfg.ip_adapt_scale, | |
| ip_img_path=( | |
| None if cfg.ip_img_path is None else cfg.ip_img_path[idx] | |
| ), | |
| prompt=prompt, | |
| save_dir=f"{output_root}/multi_view", | |
| sub_idxs=[[0, 1, 2], [3, 4, 5]], | |
| pipeline=PIPELINE, | |
| seed=cfg.seed, | |
| ) | |
| textured_mesh = backproject_api( | |
| delight_model=DELIGHT, | |
| imagesr_model=IMAGESR_MODEL, | |
| mesh_path=mesh_path, | |
| color_path=f"{output_root}/multi_view/color_sample0.png", | |
| output_path=f"{output_root}/texture_mesh/{uuid}.obj", | |
| save_glb_path=f"{output_root}/texture_mesh/{uuid}.glb", | |
| skip_fix_mesh=True, | |
| delight=cfg.delight, | |
| no_save_delight_img=True, | |
| texture_wh=[cfg.texture_size, cfg.texture_size], | |
| ) | |
| drender_api( | |
| mesh_path=f"{output_root}/texture_mesh/{uuid}.obj", | |
| output_root=f"{output_root}/texture_mesh", | |
| uuid=uuid, | |
| num_images=90, | |
| elevation=[20], | |
| with_mtl=True, | |
| gen_color_mp4=True, | |
| pbr_light_factor=1.2, | |
| ) | |
| # Re-organize folders | |
| shutil.rmtree(f"{output_root}/condition") | |
| shutil.copy( | |
| f"{output_root}/texture_mesh/{uuid}/color.mp4", | |
| f"{output_root}/color.mp4", | |
| ) | |
| shutil.rmtree(f"{output_root}/texture_mesh/{uuid}") | |
| logger.info( | |
| f"Successfully generate textured mesh in {output_root}/texture_mesh" | |
| ) | |
| if __name__ == "__main__": | |
| entrypoint() | |