from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from torch.distributed._tensor import (
    distribute_tensor,
    DTensor,
    Partial,
    Replicate,
    Shard,
)
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)


# Module-level switch: when False, ReplicateComputation.forward becomes a no-op
# so the underlying sharded parameters can be accessed directly (e.g. for
# initialization or inspection).
_active_parametrization = True


@contextmanager
def disable_data_parallel():
    """Temporarily disable the data-parallel parametrization within a `with` block."""
    global _active_parametrization
    try:
        _active_parametrization = False
        yield
    finally:
        _active_parametrization = True
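

# Usage sketch (illustrative, not part of this module's API): parameter
# initialization usually needs to see the sharded parameters directly rather
# than the replicated compute view, so it is typically run with the
# parametrization disabled. `init_weights()` below is a hypothetical method on
# the wrapped model.
#
#     model = data_parallel(model, device_mesh, mode="fully_shard")
#     with disable_data_parallel():
#         model.init_weights()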


@dataclass(frozen=True)
class MixedPrecisionPolicy:
    """Dtype overrides: param_dtype for the parameter all-gather / compute,
    reduce_dtype for gradient reduction. None leaves the dtype unchanged."""

    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
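
# A minimal mixed-precision sketch (assumed values, not a recommendation): cast
# parameters to bf16 for compute and keep gradient reduction in fp32, then pass
# the policy through `data_parallel`:
#
#     mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
#     model = data_parallel(model, device_mesh, mode="fully_shard", mp_policy=mp)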


def fsdp_policy():
    # Selective activation checkpointing policy: always recompute the FSDP
    # all-gather (and the dtype cast that accompanies it) in backward instead of
    # saving the unsharded parameters, and save everything else. This is what
    # frees the replicated parameters after forward.
    def _fsdp_recomp_policy():
        def _custom_policy(ctx, func, *args, **kwargs):
            to_recompute = func in {
                torch.ops._c10d_functional.all_gather_into_tensor.default,
                torch.ops._c10d_functional.wait_tensor.default,
                torch.ops.aten._to_copy.default,
            }
            return (
                CheckpointPolicy.MUST_RECOMPUTE
                if to_recompute
                else CheckpointPolicy.MUST_SAVE
            )

        return _custom_policy

    return create_selective_checkpoint_contexts(_fsdp_recomp_policy())


class ReplicateComputation(torch.nn.Module):
    """Parametrization that all-gathers a sharded parameter into a replicated
    tensor for local compute and marks gradients as Partial(avg) for reduction."""

    def __init__(self, device_mesh, param_sharding, mode, regional_ac, mp_policy):
        super().__init__()
        self.device_mesh = device_mesh
        self.param_sharding = param_sharding
        self.mode = mode
        self.compute_placements = [Replicate()] * self.device_mesh.ndim
        self.grad_placements = [Partial(reduce_op="avg")] * self.device_mesh.ndim
        self.regional_ac = regional_ac
        mp_policy = mp_policy or MixedPrecisionPolicy()
        self.param_dtype = mp_policy.param_dtype
        self.reduce_dtype = mp_policy.reduce_dtype

    def replicate_compute(self, x):
        # Data-parallel runtime: replicate the sharded parameter for local
        # compute; the resulting gradients are Partial tensors that get reduced
        # on the way back (DDP: all-reduce, FSDP: reduce-scatter, HSDP: both).
        # The branch below handles FSDP + TP, where x is a 2D DTensor.
        if self.mode == "fully_shard" and x._spec.mesh.ndim == 2:
            dp_placement, tp_placement = x._spec.placements
            dp_mesh, tp_mesh = self.device_mesh, x._spec.mesh["tp"]

            # Re-wrap the 2D DTensor as a 1D DTensor on the DP mesh so the
            # all-gather only runs over the data-parallel dimension.
            sharded_local_tensor = x.to_local()
            sharded_dtensor = DTensor.from_local(
                sharded_local_tensor, dp_mesh, self.param_sharding
            )

            # The FSDP all-gather on the DP mesh.
            replicated_dtensor = sharded_dtensor.redistribute(
                placements=self.compute_placements,
                # Mixed-precision casts; assumes a PyTorch build where
                # DTensor.redistribute accepts these dtype kwargs (no-op when None).
                forward_dtype=self.param_dtype,
                backward_dtype=self.reduce_dtype,
            )

            # Drop back to a local tensor (recording the Partial(avg) grad
            # placement) and re-wrap it as a 1D DTensor on the TP mesh so TP
            # compute sees the original tp_placement.
            replicated_local_tensor = replicated_dtensor.to_local(
                grad_placements=self.grad_placements
            )
            output = DTensor.from_local(
                replicated_local_tensor, tp_mesh, (tp_placement,)
            )
        else:
            # Pure DP (DDP / FSDP / HSDP without TP): replicate across the whole
            # mesh and return the local tensor for compute.
            output = x.redistribute(
                placements=self.compute_placements,
                # Same mixed-precision assumption as above.
                forward_dtype=self.param_dtype,
                backward_dtype=self.reduce_dtype,
            ).to_local(grad_placements=self.grad_placements)

        return output

    def forward(self, x):
        global _active_parametrization
        # _active_parametrization is only flipped off outside of training, e.g.
        # for model inspection or initialization via `disable_data_parallel()`;
        # in that case return the sharded parameter untouched.
        if not _active_parametrization:
            return x

        if self.regional_ac and self.mode in ("fully_shard", "hybrid_shard"):
            # Wrap the replicate compute in selective checkpointing so the
            # all-gathered parameter is recomputed in backward rather than saved
            # (i.e. reshard-after-forward behavior).
            output = checkpoint(
                self.replicate_compute, x, use_reentrant=False, context_fn=fsdp_policy
            )
        else:
            output = self.replicate_compute(x)

        return output


def data_parallel(
    model,
    device_mesh,
    mode="replicate",
    ac_mode: str = "none",
    mp_policy: Optional[MixedPrecisionPolicy] = None,
):
    if mode == "replicate":
        param_sharding = (Replicate(),)
    elif mode == "fully_shard":
        param_sharding = (Shard(0),)
    elif mode == "hybrid_shard":
        # Replicate across the outer (inter-node) mesh dim, shard dim 0 across
        # the inner (intra-node) mesh dim.
        param_sharding = (Replicate(), Shard(0))
        assert (
            device_mesh.ndim == 2
        ), "hybrid sharded data parallel requires 2D DeviceMesh"
    else:
        raise ValueError(f"Unsupported mode {mode}")

    modules = list(model.modules())

    # Apply regional activation checkpointing (with fsdp_policy) only when no
    # global AC mode is configured, to avoid checkpointing twice.
    regional_ac = ac_mode == "none"

    for mod in modules:
        params_dict = dict(mod.named_parameters(recurse=False))
        for p_name, p in params_dict.items():
            if p is not None and p.numel() > 0:
                # Shard (or replicate) the parameter across the mesh according
                # to param_sharding.
                mod.register_parameter(
                    p_name,
                    nn.Parameter(distribute_tensor(p, device_mesh, param_sharding)),
                )
                # Register the parametrization so every parameter access runs
                # ReplicateComputation.forward (all-gather + local compute view).
                nn.utils.parametrize.register_parametrization(
                    mod,
                    p_name,
                    ReplicateComputation(
                        device_mesh,
                        param_sharding,
                        mode,
                        regional_ac,
                        mp_policy=mp_policy,
                    ),
                    unsafe=True,
                )

    return model
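

# Usage sketch (illustrative): this module is meant to run under torchrun with a
# process group already initialized; `init_device_mesh`, the mesh name "dp", and
# the toy model below are assumptions for the example, not part of this module.
#
#     from torch.distributed.device_mesh import init_device_mesh
#
#     world_size = torch.distributed.get_world_size()
#     mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",))
#     model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1)).cuda()
#     model = data_parallel(model, mesh, mode="fully_shard")
#     loss = model(torch.randn(8, 16, device="cuda")).sum()
#     loss.backward()  # gradients reduce (avg) back onto the sharded parameters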