import contextlib
from typing import Union
from warnings import warn

import psutil
import torch
from torch import nn
from torch.autograd.graph import saved_tensors_hooks

from torchtitan.tools.logging import logger

try:
    import torchao
    from torchao.dtypes.nf4tensor import NF4Tensor
except ImportError:
    torchao = None
    NF4Tensor = None
logger.warning("torchao not found. ") |


class OffloadActivations(saved_tensors_hooks):
    """Context manager under which activation tensors created in the forward pass will be offloaded.

    Enables the memory efficiency technique of activation offloading, where activations bigger than
    min_offload_size bytes are offloaded to CPU in the forward pass and brought back in the backward
    pass. This is in contrast to keeping the activations in GPU VRAM throughout the program.

    This manager contains the option of using one additional CUDA stream to handle the communication
    between CUDA and CPU, which is intended to overlap with the default computation stream to improve
    runtime. We designed synchronization with a few heuristics for optimizing the tradeoff between
    runtime and memory usage.

    Args:
        use_pin_memory (bool): Whether or not the offloaded Tensor will be placed in pinned
            memory on the CPU. Pinned memory allows the Tensor to be moved back onto GPU more quickly
            but is a limited resource. Default: True.

        use_streams (bool): Whether or not to use streams for performance optimization where
            the communications get overlapped with the computation. Requires torch-2.5.0 or later.
            Default: True.

        max_fwd_stash_size (int): The maximum size of the forward stash, or the maximum number of
            consecutive activations to keep alive during the forward pass. This number must be at
            least 1. Keeping alive more activations will potentially allow more overlap between the
            communication and compute streams at the cost of increased memory usage. Keeping alive
            fewer activations will conserve memory, but may cause poor overlap between the streams,
            increasing runtime. Default: 5.

        min_offload_size (int): The minimum number of bytes a Tensor must occupy in order to qualify
            for offloading. If the tensor is too small, we do not want to waste bandwidth and resources
            moving it to CPU and back. Default: 1024 bytes.

    Raises:
        ValueError: if max_fwd_stash_size is not at least 1.

    Example:
        >>> with OffloadActivations():
        >>>     logits = model(inputs)
        >>>     loss = ...
        >>> loss.backward()
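
        The knobs can be tuned to trade memory for stream overlap (illustrative values,
        not recommendations):

        >>> with OffloadActivations(use_streams=True, max_fwd_stash_size=8, min_offload_size=4096):
        >>>     logits = model(inputs)
        >>>     loss = ...
        >>> loss.backward()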
    """

    def __init__(
        self,
        use_pin_memory: bool = True,
        use_streams: bool = True,
        max_fwd_stash_size: int = 5,
        min_offload_size: int = 1024,
    ) -> None:
        self.use_streams: bool = use_streams

        # Tensors smaller than this are not worth the round trip to CPU and back.
        self.min_tensor_size_bytes = min_offload_size
        # tensor_id => (tensor, True if it was offloaded to CPU)
        self.tracker = {}
        self.tensor_id: int = 0
        self.is_first_forward_call = True
        self.is_first_backward_call = True
        self.is_first_forward_pass = True

        self.use_pin_memory: bool = use_pin_memory
        # Warn once CPU RAM usage exceeds this percentage.
        self.virtual_memory_safe_pct = 60

        # s0 is the default (compute) stream.
        self.s0 = torch.cuda.default_stream()

        if self.use_streams:
            # s1 is the side stream used for CPU<->GPU copies.
            self.s1 = torch.cuda.Stream()
            # tensor_id => (activation, D2H-copy event)
            self.fwd_stash = {}
            if max_fwd_stash_size < 1:
                raise ValueError(
                    f"max_fwd_stash_size should be at least 1 but is {max_fwd_stash_size}"
                )
            self.max_fwd_stash_size = max_fwd_stash_size
            # tensor_id => activation brought back during the backward pass
            self.bwd_tensor_stash = {}
            # tensor_id => event marking the last use on the compute stream
            self.bwd_ev_stash = {}
            self.curr_graph_id = None
            self.curr_autograd_node = None

        def verify_sufficient_virtual_memory():
            curr_pct = get_cpu_ram_pct()
            if curr_pct > self.virtual_memory_safe_pct:
                warn(
                    f"***** WARNING: {curr_pct=}% > {self.virtual_memory_safe_pct=}% of virtual memory used"
                )

        def get_cpu_ram_pct() -> float:
            # Percentage of system RAM currently in use.
            return psutil.virtual_memory().percent

        def get_tensor_id() -> int:
            # Hand out a new, monotonically increasing id for each saved tensor.
            self.tensor_id += 1
            return self.tensor_id

        def get_num_bytes_tensor(x: torch.Tensor) -> int:
            # Size of the tensor's data in bytes (element size * number of elements).
            return x.element_size() * x.nelement()

        def pack_tensor(activation: torch.Tensor) -> int:
            # Called by autograd during the forward pass; returns a tensor_id that
            # autograd hands back to the unpack hook in the backward pass.
            if self.is_first_forward_call:
                assert (
                    len(self.tracker) == 0
                ), "backward pass should have cleared tracker of all tensors"

                # Reset flags for the new forward pass.
                self.is_first_forward_call = False
                self.is_first_backward_call = True

            num_bytes = get_num_bytes_tensor(activation)
            tensor_id = get_tensor_id()

            # Only offload sizable CUDA tensors that are actual activations
            # (heuristic: anything that is not a parameter or a buffer).
            if (
                activation.is_cuda
                and num_bytes >= self.min_tensor_size_bytes
                and (
                    not isinstance(activation, torch.nn.Parameter)
                    and not isinstance(activation, torch.nn.Buffer)
                )
            ):
                if self.use_streams:
                    # Evict the oldest entries once the forward stash exceeds
                    # max_fwd_stash_size, waiting for their D2H copies to finish
                    # before dropping the GPU reference.
                    for id in [k for k in self.fwd_stash.keys()]:
                        if id <= tensor_id - self.max_fwd_stash_size:
                            _, ev = self.fwd_stash[id]
                            self.s0.wait_event(ev)
                            del self.fwd_stash[id]
                        else:
                            break

                    # The copy stream has to wait until the activation has been
                    # produced on the compute stream.
                    self.s1.wait_stream(self.s0)

                stream = self.s1 if self.use_streams else self.s0
                with torch.cuda.stream(stream):
                    try:
                        cpu_tensor = torch.empty_like(
                            activation, pin_memory=self.use_pin_memory, device="cpu"
                        )
                    except NotImplementedError as e:
                        if (
                            isinstance(activation, NF4Tensor)
                            and torchao.__version__ < "0.6.0.dev20240917"
                        ):
                            raise RuntimeError(
                                "Offloading NF4Tensors requires torchao-0.6.0.dev20240917 or later"
                            ) from e
                        raise e
                    cpu_tensor.copy_(activation, non_blocking=True)
                    self.tracker[tensor_id] = (
                        cpu_tensor,
                        True,  # True: offloaded to CPU, must be brought back in backward
                    )

                if self.use_streams:
                    event = self.s1.record_event()

                    # Keep the GPU activation alive until its D2H copy has completed.
                    self.fwd_stash[tensor_id] = (activation, event)
            else:
                self.tracker[tensor_id] = (
                    activation,
                    False,  # False: kept on device as-is
                )

            return tensor_id

        # Unpack hook used when use_streams is False: reload the offloaded
        # activation on the default stream during the backward pass.
        def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor:
            if self.is_first_backward_call:
                if self.is_first_forward_pass:
                    self.is_first_forward_pass = False
                    if self.use_pin_memory:
                        verify_sufficient_virtual_memory()

                self.is_first_backward_call = False
                self.is_first_forward_call = True

            assert (
                unpack_tensor_id in self.tracker
            ), f"untracked tensor with id {unpack_tensor_id}"

            maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
            if modified:
                # The activation was offloaded; copy it back to the GPU.
                gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
                maybe_gpu_tensor = gpu_tensor

            # Each saved tensor is consumed exactly once in backward, so clear the entry.
            del self.tracker[unpack_tensor_id]
            return maybe_gpu_tensor

        # Unpack hook used when use_streams is True: reload the offloaded activation
        # on the side stream and synchronize carefully with the compute stream.
        def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor:
            if self.is_first_backward_call:
                self.curr_graph_id = torch._C._current_graph_task_id()

                def wait_and_del_remaining_references() -> None:
                    for id in [k for k in self.bwd_tensor_stash.keys()]:
                        event = self.bwd_ev_stash[id]
                        self.s1.wait_event(event)
                        del self.bwd_tensor_stash[id]

                # Drop any leftover references once the whole backward graph has executed.
                torch.autograd.variable.Variable._execution_engine.queue_callback(
                    wait_and_del_remaining_references
                )

                if self.is_first_forward_pass:
                    self.is_first_forward_pass = False
                    if self.use_pin_memory:
                        verify_sufficient_virtual_memory()

                self.is_first_backward_call = False
                self.is_first_forward_call = True

            assert (
                unpack_tensor_id in self.tracker
            ), f"untracked tensor with id {unpack_tensor_id}"

            maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
            if modified:
                # Get the current autograd node so we can register a post-hook on it.
                graph_id = torch._C._current_graph_task_id()
                node = torch._C._current_autograd_node()
                prev_node_ids = []

                # If backward has moved on to a new node in the same graph, the tensors
                # stashed for the previous node can be released after this node's hook runs.
                if graph_id == self.curr_graph_id and self.curr_autograd_node != node:
                    self.curr_autograd_node = node
                    prev_node_ids = [id for id in self.bwd_tensor_stash.keys()]

                brought_back_from_cpu = True
                if unpack_tensor_id in self.fwd_stash:
                    # The activation is still in the forward stash, i.e. it never left
                    # the GPU, so reuse it directly.
                    maybe_gpu_tensor = self.fwd_stash[unpack_tensor_id][0]
                    brought_back_from_cpu = False
                else:
                    # Reload the activation on the side stream.
                    with torch.cuda.stream(self.s1):
                        gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
                        maybe_gpu_tensor = gpu_tensor

                    # The compute stream has to wait for the H2D copy before using the tensor.
                    self.s0.wait_stream(self.s1)

                    # Stash a reference so the reloaded tensor is not freed prematurely.
                    self.bwd_tensor_stash[unpack_tensor_id] = maybe_gpu_tensor

                    # Snapshot the refcount of the reloaded tensor's storage. In the node
                    # post-hook below we compare against it: if the refcount has grown,
                    # autograd still holds a reference and the tensor will be used by a
                    # later node, so ownership is handed to the compute stream via
                    # record_stream(); otherwise we record an event and free the tensor
                    # once the compute stream has moved past it.
                    storage_refcount = torch._C._storage_Use_Count(
                        maybe_gpu_tensor.untyped_storage()._cdata
                    )

                def hook(outputs, inputs):
                    # Post-hook that runs once the current autograd node has executed.
                    if brought_back_from_cpu:
                        # Decide how to release our reference to the reloaded tensor,
                        # using the refcount snapshot taken above.
                        unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
                        if (
                            torch._C._storage_Use_Count(
                                unpacked_tensor.untyped_storage()._cdata
                            )
                            > storage_refcount
                        ):
                            unpacked_tensor.record_stream(self.s0)
                            del self.bwd_tensor_stash[unpack_tensor_id]
                        else:
                            event = self.s0.record_event()
                            self.bwd_ev_stash[unpack_tensor_id] = event

                    # We are in the backward pass now, so the forward stash is no longer needed.
                    for id in [k for k in self.fwd_stash.keys()]:
                        _, ev = self.fwd_stash[id]
                        self.s0.wait_event(ev)
                        del self.fwd_stash[id]

                    # Tensors stashed for the previous node can now be released.
                    for id in prev_node_ids:
                        event = self.bwd_ev_stash[id]
                        self.s1.wait_event(event)
                        del self.bwd_tensor_stash[id]

                    return outputs

                node.register_hook(hook)

            # Each saved tensor is consumed exactly once in backward, so clear the entry.
            del self.tracker[unpack_tensor_id]
            return maybe_gpu_tensor

        unpack_tensor = (
            unpack_tensor_with_streams
            if self.use_streams
            else unpack_tensor_single_stream
        )
        super().__init__(pack_tensor, unpack_tensor)


class NoOpManager(saved_tensors_hooks):
    """
    A saved_tensors_hook manager used to disable any other saved_tensors_hook manager
    applied before. This relies on the behavior that only the most recently registered
    saved_tensors_hook will run.

    One example usage is to opt a local region of code out of activation offloading,
    which is usually applied globally to best track state.
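
    Example (illustrative sketch; assumes offloading is applied globally and ``model.output``
    is the region to exclude):
        >>> with OffloadActivations():
        >>>     hidden = model.layers(inputs)
        >>>     with NoOpManager():
        >>>         logits = model.output(hidden)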
    """

    def __init__(self) -> None:
        def noop(tensor):
            return tensor

        super().__init__(noop, noop)


def get_act_offloading_ctx_manager(
    model: nn.Module, enable_activation_offloading: bool
) -> Union[OffloadActivations, contextlib.nullcontext]:
    """Returns the activation offloading context manager for the model, which will be
    a null context if enable_activation_offloading is False.

    If activation offloading is enabled, we return the OffloadActivations context manager.
    If activation offloading is disabled, we return a contextlib.nullcontext.

    Args:
        model (nn.Module): the model to wrap with the activation offloading context manager.
        enable_activation_offloading (bool): whether or not to enable activation offloading
            for the model.

    Returns:
        Union[OffloadActivations, contextlib.nullcontext]: the activation offloading context
            manager for the model.
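
    Example (illustrative sketch; ``model`` and ``inputs`` are placeholders):
        >>> ctx = get_act_offloading_ctx_manager(model, enable_activation_offloading=True)
        >>> with ctx:
        >>>     logits = model(inputs)
        >>>     loss = ...
        >>> loss.backward()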
    """
    if enable_activation_offloading:
        activations_handling_ctx = OffloadActivations()

        # Disable offloading for the output head (e.g. the final projection onto the
        # vocabulary): its activation is large and is needed again almost immediately
        # in the backward pass, so offloading it costs more than it saves. We wrap it
        # in a NoOpManager so the offloading hooks become no-ops for that region.
        output_head_detected = False
        noop_ctx = NoOpManager()

        if hasattr(model, "output"):
            if isinstance(model.output, nn.Module):
                model.output.register_forward_pre_hook(
                    lambda *args: noop_ctx.__enter__()
                )
                model.output.register_forward_hook(
                    lambda *args: noop_ctx.__exit__(), always_call=True
                )
                logger.debug("Registered NoOpManager hooks on model.output")
                output_head_detected = True

        if not output_head_detected:
            logger.warning(
                "During activation offloading, no output head was detected. "
                "If your model has an output head, it will be offloaded. "
                "This usually greatly slows training, given the large vocabulary size. "
                "To change this behavior, set your output head as model.output and make it "
                "an nn.Module."
            )

    else:
        activations_handling_ctx = contextlib.nullcontext()

    return activations_handling_ctx