Spaces:
Paused
Paused
| # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates | |
| # // | |
| # // Licensed under the Apache License, Version 2.0 (the "License"); | |
| # // you may not use this file except in compliance with the License. | |
| # // You may obtain a copy of the License at | |
| # // | |
| # // http://www.apache.org/licenses/LICENSE-2.0 | |
| # // | |
| # // Unless required by applicable law or agreed to in writing, software | |
| # // distributed under the License is distributed on an "AS IS" BASIS, | |
| # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # // See the License for the specific language governing permissions and | |
| # // limitations under the License. | |
| """ | |
| Decorators. | |
| """ | |
| import functools | |
| import threading | |
| import time | |
| from typing import Callable | |
| import torch | |
| from common.distributed import barrier_if_distributed, get_global_rank, get_local_rank | |
| from common.logger import get_logger | |
| logger = get_logger(__name__) | |
def log_on_entry(func: Callable) -> Callable:
    """
    Functions with this decorator will log the function name at entry.

    The wrapper copies the metadata of ``func`` (``__name__``, ``__doc__``, ...) with
    ``functools.wraps``, so decorators stacked on top of this one still see the real name.
    When stacking with decorators that do not preserve metadata, this must still be applied
    innermost to properly capture the name.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper that logs ``"Entering <name>"`` at INFO level and then returns
        ``func(*args, **kwargs)``.
    """

    @functools.wraps(func)
    def log_on_entry_wrapper(*args, **kwargs):
        logger.info(f"Entering {func.__name__}")
        return func(*args, **kwargs)

    return log_on_entry_wrapper
def barrier_on_entry(func: Callable) -> Callable:
    """
    Functions with this decorator will start executing when all ranks are ready to enter.

    The wrapper preserves the metadata of ``func`` (``__name__``, ``__doc__``, ...) so that
    name-based decorators stacked on top of it keep reporting the original function.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper that calls ``barrier_if_distributed()`` first and then returns
        ``func(*args, **kwargs)``.
    """

    @functools.wraps(func)
    def barrier_on_entry_wrapper(*args, **kwargs):
        # Synchronize before the wrapped function starts.
        barrier_if_distributed()
        return func(*args, **kwargs)

    return barrier_on_entry_wrapper
def _conditional_execute_wrapper_factory(execute: bool, func: Callable) -> Callable:
    """
    Helper function for local_rank_zero_only and global_rank_zero_only.

    Args:
        execute: Whether this process should actually run ``func``. The value is fixed
            when the wrapper is built, not re-evaluated per call.
        func: The callable to wrap.

    Returns:
        A wrapper that returns the result of ``func`` when ``execute`` is true and ``None``
        otherwise. The wrapper keeps the metadata of ``func`` (``__name__``, ``__doc__``, ...).
    """

    @functools.wraps(func)
    def conditional_execute_wrapper(*args, **kwargs):
        # Only execute if needed.
        result = func(*args, **kwargs) if execute else None
        # All GPUs must wait, including the ones that skipped the call.
        barrier_if_distributed()
        # Return results.
        return result

    return conditional_execute_wrapper
def _asserted_wrapper_factory(condition: bool, func: Callable, err_msg: str = "") -> Callable:
    """
    Helper function for some functions with special constraints,
    especially functions called by other global_rank_zero_only / local_rank_zero_only ones,
    in case they are wrongly invoked in other scenarios.

    Args:
        condition: Whether calls to ``func`` are allowed. The value is fixed when the
            wrapper is built, not re-evaluated per call.
        func: The callable to guard.
        err_msg: Message carried by the ``AssertionError`` raised on a forbidden call.

    Returns:
        A wrapper that returns ``func(*args, **kwargs)`` when ``condition`` is true.
        The wrapper keeps the metadata of ``func`` (``__name__``, ``__doc__``, ...).

    Raises:
        AssertionError: On every call of the wrapper when ``condition`` is false.
    """

    @functools.wraps(func)
    def asserted_execute_wrapper(*args, **kwargs):
        # Raise explicitly instead of using an ``assert`` statement: asserts are removed
        # under ``python -O``, which would silently disable this guard.
        if not condition:
            raise AssertionError(err_msg)
        return func(*args, **kwargs)

    return asserted_execute_wrapper
def local_rank_zero_only(func: Callable) -> Callable:
    """
    Restrict execution of the decorated function to local rank zero.

    The rank comparison is done once, when the decorator is applied, and the outcome is
    handed to ``_conditional_execute_wrapper_factory`` together with ``func``.
    """
    is_local_rank_zero = get_local_rank() == 0
    return _conditional_execute_wrapper_factory(is_local_rank_zero, func)
def global_rank_zero_only(func: Callable) -> Callable:
    """
    Restrict execution of the decorated function to global rank zero.

    The rank comparison is done once, when the decorator is applied, and the outcome is
    handed to ``_conditional_execute_wrapper_factory`` together with ``func``.
    """
    is_global_rank_zero = get_global_rank() == 0
    return _conditional_execute_wrapper_factory(is_global_rank_zero, func)
def assert_only_global_rank_zero(func: Callable) -> Callable:
    """
    Guard the decorated function so that only global rank zero may call it.

    The rank comparison is done once, when the decorator is applied; calling the
    decorated function from any other process fails with an ``AssertionError``.
    """
    is_global_rank_zero = get_global_rank() == 0
    message = "Not accessible to processes with global_rank != 0"
    return _asserted_wrapper_factory(is_global_rank_zero, func, err_msg=message)
def assert_only_local_rank_zero(func: Callable) -> Callable:
    """
    Guard the decorated function so that only local rank zero may call it.

    The rank comparison is done once, when the decorator is applied; calling the
    decorated function from any other process fails with an ``AssertionError``.
    """
    is_local_rank_zero = get_local_rank() == 0
    message = "Not accessible to processes with local_rank != 0"
    return _asserted_wrapper_factory(is_local_rank_zero, func, err_msg=message)
def new_thread(func: Callable) -> Callable:
    """
    Functions with this decorator will run in a new thread.
    The function will return the thread, which can be joined to wait for completion.

    The return value of ``func`` itself is not propagated to the caller; only the
    ``threading.Thread`` object is returned. The wrapper keeps the metadata of ``func``
    (``__name__``, ``__doc__``, ...).

    Args:
        func: The callable to run in a separate thread on every call.

    Returns:
        A wrapper that starts a thread running ``func(*args, **kwargs)`` and returns
        the already-started thread.
    """

    @functools.wraps(func)
    def new_thread_wrapper(*args, **kwargs):
        thread = threading.Thread(target=func, args=args, kwargs=kwargs)
        thread.start()
        return thread

    return new_thread_wrapper
def log_runtime(func: Callable) -> Callable:
    """
    Functions with this decorator will log their runtime.

    A barrier is issued before the timer starts and again after the function returns.
    The barriers go through ``barrier_if_distributed()``, like the other decorators in
    this module, instead of calling ``torch.distributed.barrier()`` directly, which
    fails when no process group has been initialized.
    The wrapper keeps the metadata of ``func`` (``__name__``, ``__doc__``, ...).

    Args:
        func: The callable to time.

    Returns:
        A wrapper that returns ``func(*args, **kwargs)`` and logs the elapsed wall-clock
        time in seconds at INFO level.
    """

    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        # Synchronize before starting the clock.
        barrier_if_distributed()
        start = time.perf_counter()
        result = func(*args, **kwargs)
        # Synchronize again before reading the clock.
        barrier_if_distributed()
        logger.info(f"Completed {func.__name__} in {time.perf_counter() - start:.3f} seconds.")
        return result

    return wrapped