# -*- coding: utf-8 -*-

import contextlib
import functools
import os
from enum import Enum
from functools import lru_cache
from typing import Any, Callable, Dict, Literal, Optional, Tuple

import torch
import triton
from packaging import version


def tensor_cache(
    fn: Callable[..., torch.Tensor]
) -> Callable[..., torch.Tensor]:
    """
    A decorator that caches the most recent result of a function with tensor inputs.

    This decorator will store the output of the decorated function for the most recent set of input tensors.
    If the function is called again with the same input tensors, it will return the cached result.

    Args:
        fn (Callable[..., torch.Tensor]):
            The function to be decorated. It should take tensor inputs and return tensor outputs.

    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of the input function with single-entry caching.
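
    Example:
        An illustrative sketch; `outer` below is a hypothetical function.

        >>> @tensor_cache
        ... def outer(x: torch.Tensor) -> torch.Tensor:
        ...     return x @ x.T
        >>> x = torch.randn(4, 4)
        >>> y1 = outer(x)  # computed and cached
        >>> y2 = outer(x)  # same tensor objects passed, so the cached result is returned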
"""
last_args: Optional[Tuple] = None
last_kwargs: Optional[Dict] = None
last_result: Any = None
@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
nonlocal last_args, last_kwargs, last_result
if last_args is not None and last_kwargs is not None:
if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
if all(a is b for a, b in zip(args, last_args)) and \
all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()):
return last_result
result = fn(*args, **kwargs)
last_args, last_kwargs, last_result = args, kwargs, result
return result
return wrapper


def input_guard(
    fn: Callable[..., torch.Tensor]
) -> Callable[..., torch.Tensor]:
    """
    A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
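
    Example:
        An illustrative sketch; `fused_add` below is a hypothetical function.

        >>> @input_guard
        ... def fused_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        ...     # both inputs arrive contiguous, and the current device is set
        ...     # to the device of the first tensor argument
        ...     return x + y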
"""
@functools.wraps(fn)
def wrapper(*args, **kwargs):
contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}
tensor = None
for arg in args:
if isinstance(arg, torch.Tensor):
tensor = arg
break
if tensor is None:
for value in kwargs.values():
if isinstance(value, torch.Tensor):
tensor = value
break
if tensor is not None:
ctx = custom_device_ctx(tensor.device.index)
else:
ctx = contextlib.nullcontext()
with ctx:
return fn(*contiguous_args, **contiguous_kwargs)
return wrapper


contiguous = input_guard


def require_version(version, hint):
    """
    Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
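
    The decorated function is expected to take a `ctx` object as its first argument, as in
    `torch.autograd.Function.forward`.

    Example:
        An illustrative sketch; the version spec, hint and `forward` below are placeholders.

        >>> @require_version('triton>=3.0', 'Please upgrade triton')
        ... def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        ...     return x * 2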
"""
def decorator(fn):
@functools.wraps(fn)
def wrapper(ctx, *args, **kwargs):
from transformers.utils.versions import require_version
require_version(version, hint)
return fn(ctx,
*(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args),
**{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()})
return wrapper
return decorator


def checkpoint(fn):
    """Wrap `fn` with `torch.utils.checkpoint.checkpoint` so its activations are recomputed in backward instead of stored."""
    def wrapper(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)
    return wrapper


@lru_cache(maxsize=None)
def check_pytorch_version(version_s: str = '2.4') -> bool:
    return version.parse(torch.__version__) >= version.parse(version_s)


def _cpu_device_warning():
    import warnings
    warnings.warn('Triton is not supported on the current platform; falling back to CPU.', stacklevel=1)


@lru_cache(maxsize=None)
def get_multiprocessor_count(tensor_idx: int = 0) -> int:
    """Return the multiprocessor (SM) count of the given device, or -1 if Triton is unavailable."""
    try:
        return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)['multiprocessor_count']
    except BaseException:
        _cpu_device_warning()
        return -1
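

# Illustrative usage sketch: the SM count can be used to size a persistent-kernel
# launch grid (`x`, `T` and `BT` below are hypothetical):
#
#     NUM_SM = get_multiprocessor_count(x.device.index)
#     grid = (NUM_SM,) if NUM_SM > 0 else (triton.cdiv(T, BT),)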


@lru_cache(maxsize=None)
def get_available_device() -> str:
    """Return the active Triton backend name (e.g. 'cuda', 'hip', 'xpu'), or 'cpu' if Triton is unavailable."""
    try:
        return triton.runtime.driver.active.get_current_target().backend
    except BaseException:
        _cpu_device_warning()
        return 'cpu'


@lru_cache(maxsize=None)
def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']:
    """Map the Triton backend name to a GPU vendor name."""
    device = get_available_device()
    if device == 'cuda':
        return 'nvidia'
    elif device == 'hip':
        return 'amd'
    elif device == 'xpu':
        return 'intel'
    else:
        return device


# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
device = get_available_device() if get_available_device() != 'hip' else 'cuda'
device_torch_lib = getattr(torch, device)
device_platform = _check_platform()

is_amd = (device_platform == 'amd')
is_intel = (device_platform == 'intel')
is_nvidia = (device_platform == 'nvidia')
is_intel_alchemist = (is_intel and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0))
is_nvidia_hopper = (is_nvidia and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9))
use_cuda_graph = (is_nvidia and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1')

# TF32 requires Nvidia Ampere or newer; AMD and Intel have not been checked yet.
is_tf32_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8)
is_gather_supported = hasattr(triton.language, 'gather')
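
# Illustrative sketch of how these flags typically gate kernel variants downstream
# (`kernel_with_gather` and `kernel_without_gather` are hypothetical):
#
#     kernel = kernel_with_gather if is_gather_supported else kernel_without_gather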


def get_all_max_shared_mem():
    """Return the maximum shared memory (in bytes) reported by Triton for each visible device, or [-1] on failure."""
    try:
        return [
            triton.runtime.driver.active.utils.get_device_properties(i)['max_shared_mem']
            for i in range(device_torch_lib.device_count())
        ]
    except BaseException:
        _cpu_device_warning()
        return [-1]


class Backend(Enum):
    """Per-architecture shared memory budgets in bytes."""

    ADA = 101376  # RTX 4090
    AMPERE = 166912  # A100
    HOPPER = 232448  # H100
    DEFAULT = 102400  # Default

    @classmethod
    def get_shared_memory(cls, arch: str) -> int:
        try:
            return cls[arch.upper()].value
        except KeyError:
            return cls.DEFAULT.value


@lru_cache(maxsize=None)
def check_shared_mem(arch: str = 'none', tensor_idx: int = 0) -> bool:
    """Return True if device `tensor_idx` offers at least as much shared memory as the given architecture tier."""
    try:
        device_shared_mem_list = get_all_max_shared_mem()
        max_shared_memory = device_shared_mem_list[tensor_idx]
        return max_shared_memory >= Backend.get_shared_memory(arch)
    except Exception:
        return False
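

# Illustrative sketch of how `check_shared_mem` can be used to pick kernel tile
# sizes (the block sizes below are hypothetical):
#
#     BT = 128 if check_shared_mem('hopper') else 64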


if check_pytorch_version('2.4'):
    device = 'cuda' if device == 'cpu' else device
    autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
    autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)

    def custom_device_ctx(index: int):
        return device_torch_lib.device(index)
else:
    assert device == 'cuda', 'Only cuda device is supported for PyTorch version < 2.4.0.'
    autocast_custom_fwd = device_torch_lib.amp.custom_fwd
    autocast_custom_bwd = device_torch_lib.amp.custom_bwd

    def custom_device_ctx(index: int):
        return torch.cuda.device(index)
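

# Illustrative sketch of how the helpers above can be combined in a custom
# autograd op (the op itself is hypothetical):
#
#     class Scale(torch.autograd.Function):
#
#         @staticmethod
#         @input_guard
#         @autocast_custom_fwd
#         def forward(ctx, x):
#             return x * 2
#
#         @staticmethod
#         @input_guard
#         @autocast_custom_bwd
#         def backward(ctx, dy):
#             return dy * 2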