Spaces:
Running
on
Zero
Running
on
Zero
| # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT | |
| # except for the third-party components listed below. | |
| # Hunyuan 3D does not impose any additional limitations beyond what is outlined | |
| # in the repsective licenses of these third-party components. | |
| # Users must comply with all terms and conditions of original licenses of these third-party | |
| # components and must ensure that the usage of the third party components adheres to | |
| # all relevant laws and regulations. | |
| # For avoidance of doubts, Hunyuan 3D means the large language models and | |
| # their software and algorithms, including trained model weights, parameters (including | |
| # optimizer states), machine-learning model code, inference-enabling code, training-enabling code, | |
| # fine-tuning enabling code and other elements of the foregoing made publicly available | |
| # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT. | |
| import logging | |
| import os | |
| from functools import wraps | |
| import torch | |
| def get_logger(name): | |
| logger = logging.getLogger(name) | |
| logger.setLevel(logging.INFO) | |
| console_handler = logging.StreamHandler() | |
| console_handler.setLevel(logging.INFO) | |
| formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |
| console_handler.setFormatter(formatter) | |
| logger.addHandler(console_handler) | |
| return logger | |
| logger = get_logger('hy3dgen.shapgen') | |
| class synchronize_timer: | |
| """ Synchronized timer to count the inference time of `nn.Module.forward`. | |
| Supports both context manager and decorator usage. | |
| Example as context manager: | |
| ```python | |
| with synchronize_timer('name') as t: | |
| run() | |
| ``` | |
| Example as decorator: | |
| ```python | |
| @synchronize_timer('Export to trimesh') | |
| def export_to_trimesh(mesh_output): | |
| pass | |
| ``` | |
| """ | |
| def __init__(self, name=None): | |
| self.name = name | |
| def __enter__(self): | |
| """Context manager entry: start timing.""" | |
| if os.environ.get('HY3DGEN_DEBUG', '0') == '1': | |
| self.start = torch.cuda.Event(enable_timing=True) | |
| self.end = torch.cuda.Event(enable_timing=True) | |
| self.start.record() | |
| return lambda: self.time | |
| def __exit__(self, exc_type, exc_value, exc_tb): | |
| """Context manager exit: stop timing and log results.""" | |
| if os.environ.get('HY3DGEN_DEBUG', '0') == '1': | |
| self.end.record() | |
| torch.cuda.synchronize() | |
| self.time = self.start.elapsed_time(self.end) | |
| if self.name is not None: | |
| logger.info(f'{self.name} takes {self.time} ms') | |
| def __call__(self, func): | |
| """Decorator: wrap the function to time its execution.""" | |
| def wrapper(*args, **kwargs): | |
| with self: | |
| result = func(*args, **kwargs) | |
| return result | |
| return wrapper | |
| def smart_load_model( | |
| model_path, | |
| subfolder, | |
| use_safetensors, | |
| variant, | |
| ): | |
| original_model_path = model_path | |
| # try local path | |
| base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen') | |
| model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder)) | |
| logger.info(f'Try to load model from local path: {model_path}') | |
| if not os.path.exists(model_path): | |
| logger.info('Model path not exists, try to download from huggingface') | |
| try: | |
| from huggingface_hub import snapshot_download | |
| # 只下载指定子目录 | |
| path = snapshot_download( | |
| repo_id=original_model_path, | |
| allow_patterns=[f"{subfolder}/*"], # 关键修改:模式匹配子文件夹 | |
| ) | |
| model_path = os.path.join(path, subfolder) # 保持路径拼接逻辑不变 | |
| except ImportError: | |
| logger.warning( | |
| "You need to install HuggingFace Hub to load models from the hub." | |
| ) | |
| raise RuntimeError(f"Model path {model_path} not found") | |
| except Exception as e: | |
| raise e | |
| if not os.path.exists(model_path): | |
| raise FileNotFoundError(f"Model path {original_model_path} not found") | |
| extension = 'ckpt' if not use_safetensors else 'safetensors' | |
| variant = '' if variant is None else f'.{variant}' | |
| ckpt_name = f'model{variant}.{extension}' | |
| config_path = os.path.join(model_path, 'config.yaml') | |
| ckpt_path = os.path.join(model_path, ckpt_name) | |
| return config_path, ckpt_path | |