import contextlib
import functools
import os
from enum import Enum
from functools import lru_cache
from typing import Any, Callable, Dict, Literal, Optional, Tuple

import torch
import triton
from packaging import version


def tensor_cache(
    fn: Callable[..., torch.Tensor]
) -> Callable[..., torch.Tensor]:
    """
    A decorator that caches the most recent result of a function with tensor inputs.

    This decorator will store the output of the decorated function for the most recent set of input tensors.
    If the function is called again with the same input tensors, it will return the cached result.

    Args:
        fn (Callable[..., torch.Tensor]):
            The function to be decorated. It should take tensor inputs and return tensor outputs.

    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of the input function with single-entry caching.
    """
    last_args: Optional[Tuple] = None
    last_kwargs: Optional[Dict] = None
    last_result: Any = None

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal last_args, last_kwargs, last_result

        if last_args is not None and last_kwargs is not None:
            if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
                if all(a is b for a, b in zip(args, last_args)) and \
                        all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()):
                    return last_result

        result = fn(*args, **kwargs)
        last_args, last_kwargs, last_result = args, kwargs, result
        return result

    return wrapper
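
# A minimal usage sketch for `tensor_cache` (illustrative only; `sum_pair` is a
# hypothetical function, not part of this module). The cache is keyed by object
# identity, so repeating a call with the exact same tensor objects hits the cache,
# while a fresh tensor (even with equal values) triggers recomputation:
#
#     @tensor_cache
#     def sum_pair(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
#         return a + b
#
#     x, y = torch.randn(2), torch.randn(2)
#     out1 = sum_pair(x, y)          # computed
#     out2 = sum_pair(x, y)          # cache hit: the stored result is returned
#     out3 = sum_pair(x, y.clone())  # different tensor object, recomputed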


def input_guard(
    fn: Callable[..., torch.Tensor]
) -> Callable[..., torch.Tensor]:
    """
    A decorator that makes all input tensors contiguous and sets the current device
    according to the first tensor found in the inputs.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
        contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}

        tensor = None
        for arg in args:
            if isinstance(arg, torch.Tensor):
                tensor = arg
                break
        if tensor is None:
            for value in kwargs.values():
                if isinstance(value, torch.Tensor):
                    tensor = value
                    break

        if tensor is not None:
            ctx = custom_device_ctx(tensor.device.index)
        else:
            ctx = contextlib.nullcontext()

        with ctx:
            return fn(*contiguous_args, **contiguous_kwargs)

    return wrapper
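
# A minimal usage sketch for `input_guard` (illustrative; `fused_op` is a
# hypothetical kernel entry point, not part of this module). Tensor arguments are
# made contiguous and the call runs inside the device context of the first tensor:
#
#     @input_guard
#     def fused_op(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
#         return x + y  # x and y are guaranteed contiguous here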


contiguous = input_guard


def require_version(version, hint):
    """
    Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
    """
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(ctx, *args, **kwargs):
            from transformers.utils.versions import require_version
            require_version(version, hint)
            return fn(ctx,
                      *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args),
                      **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()})
        return wrapper
    return decorator
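
# Version requirements use pip-style specifiers. A sketch (the decorated function
# is assumed to be an autograd-style forward that takes `ctx` as its first
# argument, which is why the wrapper forwards `ctx` explicitly):
#
#     @require_version('triton>=2.0', 'Run `pip install -U triton`.')
#     def forward(ctx, x):
#         ...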


def checkpoint(fn):
    def wrapper(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)
    return wrapper
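
# Sketch: wrapping a forward pass for gradient checkpointing, so activations are
# recomputed in backward instead of being stored (`mlp` is hypothetical):
#
#     y = checkpoint(mlp)(x)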


@lru_cache(maxsize=None)
def check_pytorch_version(version_s: str = '2.4') -> bool:
    return version.parse(torch.__version__) >= version.parse(version_s)
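
# e.g. check_pytorch_version('2.4') is True when the installed torch is at least 2.4;
# the result is cached per version string.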


def _cpu_device_warning():
    import warnings
    warnings.warn('Triton is not supported on the current platform; falling back to CPU.', stacklevel=1)


@lru_cache(maxsize=None)
def get_multiprocessor_count(tensor_idx: int = 0) -> int:
    try:
        return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)['multiprocessor_count']
    except BaseException:
        _cpu_device_warning()
        return -1


@lru_cache(maxsize=None)
def get_available_device() -> str:
    try:
        return triton.runtime.driver.active.get_current_target().backend
    except BaseException:
        _cpu_device_warning()
        return 'cpu'
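
# Typical return values are Triton backend names such as 'cuda', 'hip' or 'xpu',
# with 'cpu' as the fallback when no Triton driver is available.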


@lru_cache(maxsize=None)
def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']:
    device = get_available_device()
    if device == 'cuda':
        return 'nvidia'
    elif device == 'hip':
        return 'amd'
    elif device == 'xpu':
        return 'intel'
    else:
        return device
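

# NOTE: for AMD GPUs Triton reports the 'hip' backend, while PyTorch exposes ROCm
# devices through the `torch.cuda` namespace, hence the 'hip' -> 'cuda' mapping below.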
device = get_available_device() if get_available_device() != 'hip' else 'cuda'
device_torch_lib = getattr(torch, device)
device_platform = _check_platform()

is_amd = (device_platform == 'amd')
is_intel = (device_platform == 'intel')
is_nvidia = (device_platform == 'nvidia')
is_intel_alchemist = (is_intel and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0))
is_nvidia_hopper = (is_nvidia and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9))
use_cuda_graph = (is_nvidia and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1')

is_tf32_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8)
is_gather_supported = hasattr(triton.language, 'gather')


def get_all_max_shared_mem():
    try:
        return [
            triton.runtime.driver.active.utils.get_device_properties(i)['max_shared_mem']
            for i in range(device_torch_lib.device_count())
        ]
    except BaseException:
        _cpu_device_warning()
        return [-1]


class Backend(Enum):
    ADA = 101376      # SM89: 99 KB of max shared memory per thread block
    AMPERE = 166912   # SM80: 163 KB
    HOPPER = 232448   # SM90: 227 KB
    DEFAULT = 102400  # conservative 100 KB fallback

    @classmethod
    def get_shared_memory(cls, arch: str) -> int:
        try:
            return cls[arch.upper()].value
        except KeyError:
            return cls.DEFAULT.value
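
# e.g. Backend.get_shared_memory('hopper') == 232448; unknown architecture names
# fall back to Backend.DEFAULT.value.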


@lru_cache(maxsize=None)
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
    try:
        device_shared_mem_list = get_all_max_shared_mem()
        max_shared_memory = device_shared_mem_list[tensor_idx]
        return max_shared_memory >= Backend.get_shared_memory(arch)
    except Exception:
        return False
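
# Sketch of typical (hypothetical) usage: pick a larger tile size only when the
# device has at least Ampere-class shared memory:
#
#     BK = 128 if check_shared_mem('ampere') else 64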


if check_pytorch_version('2.4'):
    device = 'cuda' if device == 'cpu' else device
    autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
    autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)

    def custom_device_ctx(index: int):
        return device_torch_lib.device(index)
else:
    assert device == 'cuda', 'Only cuda device is supported for PyTorch version < 2.4.0.'
    autocast_custom_fwd = device_torch_lib.amp.custom_fwd
    autocast_custom_bwd = device_torch_lib.amp.custom_bwd

    def custom_device_ctx(index: int):
        return torch.cuda.device(index)
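
# Sketch of how these helpers are typically combined on a custom autograd Function
# (illustrative; `MyOp` and its body are hypothetical):
#
#     class MyOp(torch.autograd.Function):
#
#         @staticmethod
#         @input_guard
#         @autocast_custom_fwd
#         def forward(ctx, x):
#             ...
#
#         @staticmethod
#         @input_guard
#         @autocast_custom_bwd
#         def backward(ctx, dy):
#             ...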