Spaces:
Running
on
Zero
Running
on
Zero
| import logging | |
| import os | |
| import random | |
| import time | |
| import warnings | |
| from dataclasses import dataclass, field | |
| from shutil import copy, rmtree | |
| import torch | |
| import tyro | |
| from huggingface_hub import snapshot_download | |
| from packaging import version | |
# Keep console output readable: drop FutureWarnings entirely and only let
# errors through from the transformers / diffusers loggers.
warnings.filterwarnings(action="ignore", category=FutureWarning)
for _lib_name in ("transformers", "diffusers"):
    logging.getLogger(_lib_name).setLevel(logging.ERROR)
del _lib_name
# TorchVision monkey patch for >0.16.
# torchvision 0.17 removed ``torchvision.transforms.functional_tensor``, which
# some downstream packages still import (presumably via the panorama stack
# imported below -- verify). Register a minimal stand-in module exposing
# ``rgb_to_grayscale`` before those imports run.
# The guard must look at torchvision's version, not torch's: torch versions
# (1.x / 2.x) always compare >= 0.16, so checking torch patched
# unconditionally and shadowed the real module on older torchvision releases.
import torchvision

if version.parse(torchvision.__version__).release[:2] >= (0, 17):
    import sys
    import types

    import torchvision.transforms.functional as TF

    functional_tensor = types.ModuleType(
        "torchvision.transforms.functional_tensor"
    )
    functional_tensor.rgb_to_grayscale = TF.rgb_to_grayscale
    sys.modules["torchvision.transforms.functional_tensor"] = functional_tensor
| from gsplat.distributed import cli | |
| from txt2panoimg import Text2360PanoramaImagePipeline | |
| from embodied_gen.trainer.gsplat_trainer import ( | |
| DefaultStrategy, | |
| GsplatTrainConfig, | |
| ) | |
| from embodied_gen.trainer.gsplat_trainer import entrypoint as gsplat_entrypoint | |
| from embodied_gen.trainer.pono2mesh_trainer import Pano2MeshSRPipeline | |
| from embodied_gen.utils.config import Pano2MeshSRConfig | |
| from embodied_gen.utils.gaussian import restore_scene_scale_and_position | |
| from embodied_gen.utils.gpt_clients import GPT_CLIENT | |
| from embodied_gen.utils.log import logger | |
| from embodied_gen.utils.process_media import is_image_file, parse_text_prompts | |
| from embodied_gen.validators.quality_checkers import ( | |
| PanoHeightEstimator, | |
| PanoImageOccChecker, | |
| ) | |
# Names exported by ``from <module> import *``.
__all__ = ["generate_pano_image", "entrypoint"]
@dataclass
class Scene3DGenConfig:
    """Command-line options for panorama-based 3D scene generation.

    Parsed by ``tyro.cli`` in ``entrypoint``; must be a dataclass so that
    the annotated fields (and the ``field(default_factory=...)`` default of
    ``gs3d``) become constructor arguments / CLI flags.
    """

    # Text descriptions of indoor rooms; expanded by parse_text_prompts.
    # Image (style reference) prompts are not implemented yet.
    prompts: list[str]
    # Root output directory; each prompt gets a ``scene_XXXX`` subfolder.
    output_dir: str
    # Seed for the first panorama attempt; a random seed is drawn when None.
    seed: int | None = None
    # Real height of the room in meters; estimated from the panorama when None.
    real_height: float | None = None
    # Stop after saving the panorama image (no mesh / 3DGS stages).
    pano_image_only: bool = False
    # Skip the PanoImageOccChecker quality check on generated panoramas.
    disable_pano_check: bool = False
    # Keep the gsplat result directory and the mesh-stage gs data file.
    keep_middle_result: bool = False
    # Maximum number of panorama generation attempts per prompt.
    n_retry: int = 7
    # 3D Gaussian Splatting training config; data_dir / result_dir are set per
    # scene by entrypoint.
    gs3d: GsplatTrainConfig = field(
        default_factory=lambda: GsplatTrainConfig(
            strategy=DefaultStrategy(verbose=True),
            max_steps=4000,
            init_opa=0.9,
            opacity_reg=2e-3,
            sh_degree=0,
            means_lr=1e-4,
            scales_lr=1e-3,
        )
    )
def generate_pano_image(
    prompt: str,
    output_path: str,
    pipeline,
    seed: int,
    n_retry: int,
    checker=None,
    num_inference_steps: int = 40,
) -> None:
    """Generate a panorama image for ``prompt`` and save it to ``output_path``.

    The pipeline is called up to ``n_retry`` times. After every attempt the
    image is written to ``output_path`` (overwriting the previous attempt)
    and, if a ``checker`` is given, passed to it. Generation stops as soon as
    the checker returns a flag of ``True`` or ``None``; otherwise a new random
    seed is drawn for the next attempt. If every attempt is rejected, the last
    generated image is left on disk and a warning is logged.

    Args:
        prompt: Text description of the scene. Image-file prompts are not
            supported yet.
        output_path: Path the generated image is saved to.
        pipeline: Callable taking a dict with ``prompt``,
            ``num_inference_steps``, ``upscale`` and ``seed`` keys and
            returning an object with a ``save(path)`` method.
        seed: Seed for the first attempt; retries use random seeds.
        n_retry: Maximum number of generation attempts; must be >= 1.
        checker: Optional callable ``checker(image) -> (flag, response)``.
            When ``None``, the first generated image is accepted.
        num_inference_steps: Inference steps passed to the pipeline.

    Raises:
        ValueError: If ``n_retry`` is smaller than 1 (nothing would be
            generated).
        NotImplementedError: If ``prompt`` is an image file.
    """
    if n_retry < 1:
        raise ValueError(f"n_retry must be >= 1, got {n_retry}.")
    # The prompt does not change between attempts, so check it once up front.
    if is_image_file(prompt):
        raise NotImplementedError("Image mode not implemented yet.")

    # Append layout hints favouring empty, open rooms to the user prompt.
    txt_prompt = f"{prompt}, spacious, empty, wide open, open floor, minimal furniture"

    for i in range(n_retry):
        logger.info(
            f"GEN Panorama: Retry {i+1}/{n_retry} for prompt: {prompt}, seed: {seed}"
        )
        inputs = {
            "prompt": txt_prompt,
            "num_inference_steps": num_inference_steps,
            "upscale": False,
            "seed": seed,
        }
        pano_image = pipeline(inputs)
        pano_image.save(output_path)
        if checker is None:
            return

        flag, response = checker(pano_image)
        if flag is True or flag is None:
            logger.info(f"{response}, image saved in {output_path}")
            return
        # Rejected: report it and try again with a fresh random seed.
        logger.warning(f"{response}, image saved in {output_path}")
        seed = random.randint(0, 100000)

    logger.warning(
        f"Panorama not accepted after {n_retry} attempts, "
        f"keeping the last image in {output_path}."
    )
def entrypoint(*args, **kwargs):
    """Run the panorama -> mesh -> 3DGS scene generation pipeline from the CLI.

    Options are parsed from the command line into ``Scene3DGenConfig`` via
    ``tyro``; ``*args`` / ``**kwargs`` are accepted for call-site
    compatibility but are not used.

    For each prompt, a ``<output_dir>/scene_XXXX`` folder receives
    ``prompt.txt`` and ``pano_image.png``. Unless ``pano_image_only`` is set,
    the panorama is then turned into a mesh by ``Pano2MeshSRPipeline``, a
    3DGS model is trained with gsplat (``gs_model.ply``, ``video.mp4`` and
    ``gsplat_cfg.yml`` are copied out of its result directory), and the
    mesh (``mesh_model.ply``, presumably written by the mesh stage -- confirm)
    and Gaussians are rescaled with ``restore_scene_scale_and_position``
    using the configured or estimated room height.
    """
    # Local import: the module-level name ``copy`` is ``shutil.copy``.
    from copy import deepcopy

    cfg = tyro.cli(Scene3DGenConfig)

    # Init global models.
    model_path = snapshot_download("archerfmy0831/sd-t2i-360panoimage")
    IMG2PANO_PIPE = Text2360PanoramaImagePipeline(
        model_path, torch_dtype=torch.float16, device="cuda"
    )
    PANOMESH_CFG = Pano2MeshSRConfig()
    PANO2MESH_PIPE = Pano2MeshSRPipeline(PANOMESH_CFG)
    PANO_CHECKER = PanoImageOccChecker(GPT_CLIENT, box_hw=[95, 1000])
    PANOHEIGHT_ESTOR = PanoHeightEstimator(GPT_CLIENT)

    prompts = parse_text_prompts(cfg.prompts)
    for idx, prompt in enumerate(prompts):
        start_time = time.time()
        output_dir = os.path.join(cfg.output_dir, f"scene_{idx:04d}")
        os.makedirs(output_dir, exist_ok=True)
        pano_path = os.path.join(output_dir, "pano_image.png")
        # Explicit UTF-8 so non-ASCII prompts are saved identically everywhere.
        with open(f"{output_dir}/prompt.txt", "w", encoding="utf-8") as f:
            f.write(prompt)

        generate_pano_image(
            prompt,
            pano_path,
            IMG2PANO_PIPE,
            cfg.seed if cfg.seed is not None else random.randint(0, 100000),
            cfg.n_retry,
            checker=None if cfg.disable_pano_check else PANO_CHECKER,
        )
        if cfg.pano_image_only:
            continue

        logger.info("GEN and REPAIR Mesh from Panorama...")
        PANO2MESH_PIPE(pano_path, output_dir)

        logger.info("TRAIN 3DGS from Mesh Init and Cube Image...")
        # Train from a per-scene copy of the 3DGS config. adjust_steps() runs
        # once per scene; if it rescales step counts in place (as gsplat's
        # reference trainer does -- confirm for GsplatTrainConfig), reusing
        # cfg.gs3d would compound the scaling across prompts.
        gs_cfg = deepcopy(cfg.gs3d)
        gs_cfg.data_dir = output_dir
        gs_cfg.result_dir = f"{output_dir}/gaussian"
        gs_cfg.adjust_steps(gs_cfg.steps_scaler)
        torch.set_default_device("cpu")  # recover default setting.
        cli(gsplat_entrypoint, gs_cfg, verbose=True)

        # Copy the final-step artifacts out of the gsplat result directory.
        last_step = gs_cfg.max_steps - 1
        gs_path = f"{gs_cfg.result_dir}/ply/point_cloud_{last_step}.ply"
        copy(gs_path, f"{output_dir}/gs_model.ply")
        video_path = f"{gs_cfg.result_dir}/renders/video_step{last_step}.mp4"
        copy(video_path, f"{output_dir}/video.mp4")
        gs_cfg_path = f"{gs_cfg.result_dir}/cfg.yml"
        copy(gs_cfg_path, f"{output_dir}/gsplat_cfg.yml")

        # Clean up the middle results.
        if not cfg.keep_middle_result:
            rmtree(gs_cfg.result_dir, ignore_errors=True)
            os.remove(f"{output_dir}/{PANOMESH_CFG.gs_data_file}")

        # Bring mesh and Gaussians to real-world scale using the room height.
        real_height = (
            PANOHEIGHT_ESTOR(pano_path)
            if cfg.real_height is None
            else cfg.real_height
        )
        gs_path = os.path.join(output_dir, "gs_model.ply")
        mesh_path = os.path.join(output_dir, "mesh_model.ply")
        restore_scene_scale_and_position(real_height, mesh_path, gs_path)

        elapsed_time = (time.time() - start_time) / 60
        logger.info(
            f"FINISHED 3D scene generation in {output_dir} in {elapsed_time:.2f} mins."
        )
if __name__ == "__main__":
    # Executed as a script: parse CLI options and run the full pipeline.
    entrypoint()