|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Wrapper around FSDP for more convenient use in the training loops. |
|
""" |
|
|
|
from contextlib import contextmanager |
|
import typing as tp |
|
import dora |
|
import torch |
|
|
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
|
from torch.distributed.fsdp import ( |
|
MixedPrecision, ShardingStrategy, FullStateDictConfig, StateDictType) |
|
from torch.distributed._shard.sharded_tensor.api import ShardedTensor |
|
|
|
|
|
def is_fsdp_used() -> bool:
    """Return whether we are using FSDP.

    The answer is read from the `fsdp.use` entry of the config of the
    current Dora XP. When not running inside an XP, or when the XP config
    has no `fsdp` attribute, False is returned.
    """
    if not dora.is_xp():
        return False
    xp_cfg = dora.get_xp().cfg
    if not hasattr(xp_cfg, 'fsdp'):
        return False
    return xp_cfg.fsdp.use
|
|
|
|
|
def is_sharded_tensor(x: tp.Any) -> bool:
    """Return True when `x` is a `ShardedTensor` instance, False otherwise."""
    matches = isinstance(x, ShardedTensor)
    return matches
|
|
|
|
|
@contextmanager
def switch_to_full_state_dict(models: tp.List[FSDP]):
    """Context manager putting every model in `models` in full state dict mode.

    On entry, each model is switched to `StateDictType.FULL_STATE_DICT` with a
    `FullStateDictConfig(offload_to_cpu=True, rank0_only=True)`.
    On exit, including when the body raised, each model is set to
    `StateDictType.LOCAL_STATE_DICT`, regardless of the type it had on entry.

    Args:
        models (list of FSDP): FSDP wrapped models to switch.
    """
    full_type = StateDictType.FULL_STATE_DICT
    local_type = StateDictType.LOCAL_STATE_DICT
    # The state dict type is set explicitly on each model for both transitions.
    for fsdp_model in models:
        full_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        FSDP.set_state_dict_type(fsdp_model, full_type, full_config)
    try:
        yield
    finally:
        for fsdp_model in models:
            FSDP.set_state_dict_type(fsdp_model, local_type)
|
|
|
|
|
def wrap_with_fsdp(cfg, model: torch.nn.Module,
                   block_classes: tp.Optional[tp.Set[tp.Type]] = None) -> FSDP:
    """Wraps a model with FSDP.

    Args:
        cfg: FSDP config. The entries read here are `use`, `param_dtype`,
            `reduce_dtype`, `buffer_dtype`, `sharding_strategy` and `per_block`.
        model (torch.nn.Module): Model to wrap.
        block_classes (set of types, optional): Module classes handed to
            `ModuleWrapPolicy` when `cfg.per_block` is true. Defaults to
            `{StreamingTransformerLayer, ConditioningProvider}`.
    Returns:
        FSDP: The wrapped model, with its state dict type set to `LOCAL_STATE_DICT`.
    Raises:
        AssertionError: If `cfg.use` is false, if the `full_shard` strategy is
            requested, or if the local rank is not below the CUDA device count.
        KeyError: If a dtype name or the sharding strategy name is not supported.
    """
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy

    # Imported inside the function rather than at module level
    # (presumably to avoid a circular import, to be confirmed).
    from ..modules.transformer import StreamingTransformerLayer
    from ..modules.conditioners import ConditioningProvider

    # Patch the FSDP post backward hook before any wrapping happens.
    _fix_post_backward_hook()

    assert cfg.use
    strategies = {
        "no_shard": ShardingStrategy.NO_SHARD,
        "shard_grad_op": ShardingStrategy.SHARD_GRAD_OP,
        "full_shard": ShardingStrategy.FULL_SHARD,
    }
    dtypes = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }

    precision = MixedPrecision(
        param_dtype=dtypes[cfg.param_dtype],
        reduce_dtype=dtypes[cfg.reduce_dtype],
        buffer_dtype=dtypes[cfg.buffer_dtype],
    )

    strategy = strategies[cfg.sharding_strategy]
    # `full_shard` is a known name but is explicitly rejected for now.
    assert strategy != ShardingStrategy.FULL_SHARD, \
        "Not supported at the moment, requires a bit more work."

    # The local rank is used as the CUDA device index below, so it must be valid.
    local_rank = dora.distrib.get_distrib_spec().local_rank
    assert local_rank < torch.cuda.device_count(), "Please upgrade Dora!"

    if block_classes is None:
        block_classes = {StreamingTransformerLayer, ConditioningProvider}
    # Without `per_block`, no auto wrap policy is given to FSDP.
    policy = ModuleWrapPolicy(block_classes) if cfg.per_block else None

    wrapped = _FSDPFixStateDict(
        model,
        auto_wrap_policy=policy,
        sharding_strategy=strategy,
        mixed_precision=precision,
        device_id=local_rank,
        sync_module_states=True,
        use_orig_params=True,
    )
    FSDP.set_state_dict_type(wrapped, StateDictType.LOCAL_STATE_DICT)

    # Give each wrapped module a reference to its FSDP wrapper. The reference
    # is written straight into `__dict__` so that `nn.Module.__setattr__` does not
    # register the wrapper as a submodule of the very module it wraps.
    for fsdp_module in FSDP.fsdp_modules(wrapped):
        inner_module = fsdp_module._fsdp_wrapped_module
        inner_module.__dict__['_fsdp'] = fsdp_module
    return wrapped
|
|
|
|
|
def purge_fsdp(model: FSDP):
    """Purge the FSDP cached shard inside the model. This should
    allow setting the best state or switching to the EMA.

    Every FSDP module found in `model` that has handles, and whose first
    handle reports a non empty padded unsharded flat parameter storage,
    is passed to FSDP's private `_reshard` with one `True` flag per handle.

    Args:
        model (FSDP): Model whose FSDP modules should be resharded.
    """
    from torch.distributed.fsdp._runtime_utils import _reshard
    for fsdp_module in FSDP.fsdp_modules(model):
        handles = fsdp_module._handles
        if handles:
            # Only the first handle is inspected to decide for the whole module.
            flat_param = handles[0]._get_padded_unsharded_flat_param()
            storage_size: int = flat_param._typed_storage()._size()
            # A storage of size 0 is skipped (presumably nothing is
            # materialized in that case, to be confirmed).
            if storage_size != 0:
                _reshard(fsdp_module, handles, [True] * len(handles))
|
|
|
|
|
class _FSDPFixStateDict(FSDP):
    """FSDP subclass with customized `state_dict` and `load_state_dict`.

    `state_dict` drops every `ShardedTensor` entry from the result, and
    `load_state_dict` copies values in place into the current state when
    the model is not in `FULL_STATE_DICT` mode. Both loading paths end by
    calling `purge_fsdp` on the model.
    """

    @staticmethod
    def _name_without_fsdp_prefix(name: str) -> str:
        """Return `name` with every FSDP wrapper component removed.

        Args:
            name (str): Dotted state key, possibly containing one or more
                `FSDP_WRAPPED_MODULE` components.
        Returns:
            str: The dotted key with those components filtered out.
        """
        from torch.distributed.fsdp._common_utils import FSDP_WRAPPED_MODULE
        kept = [part for part in name.split('.') if part != FSDP_WRAPPED_MODULE]
        return '.'.join(kept)

    def state_dict(self, *args, **kwargs) -> tp.Dict[str, tp.Any]:
        """Return the state dict of the model, without any `ShardedTensor` entry.

        All arguments are forwarded to the parent implementation, so that the
        optional arguments of `nn.Module.state_dict` (`destination`, `prefix`,
        `keep_vars`) are accepted, for instance when a parent module collects
        the state of its children.

        Returns:
            dict: A new plain dict. The filtering applies to this returned copy
                only, the mapping produced by the parent class is not modified.
        """
        # `super()` must be resolved here, outside of the comprehension.
        parent_state = super().state_dict(*args, **kwargs)
        return {key: value for key, value in parent_state.items()
                if not is_sharded_tensor(value)}

    def load_state_dict(self, state: tp.Dict[str, tp.Any]):
        """Load `state` into the model.

        In `FULL_STATE_DICT` mode, loading is delegated to the parent class.
        Otherwise each value is copied in place into the entry of the current
        state with the same key, once FSDP wrapper components have been removed
        from the key. Entries of the current state absent from `state` are
        left untouched.

        Args:
            state (dict): Mapping from state key to value.
        Raises:
            RuntimeError: If a key of `state` is not in the current state
                (outside of `FULL_STATE_DICT` mode).
        """
        if self._state_dict_type is StateDictType.FULL_STATE_DICT:
            super().load_state_dict(state)
            purge_fsdp(self)
            return

        # The unfiltered parent state is used here, not the `state_dict` override.
        current_state = dict(super().state_dict())
        for key, value in state.items():
            key = _FSDPFixStateDict._name_without_fsdp_prefix(key)
            if key not in current_state:
                # Unknown keys are an error, there is no non strict mode.
                raise RuntimeError(f"Unknown state key {key}")
            current_state[key].copy_(value)

        # Values were written in place, drop what FSDP may still have cached.
        purge_fsdp(self)
|
|
|
|
|
# Becomes True the first time `_fix_post_backward_hook` runs, so that the
# FSDP hook is wrapped at most once.
_hook_fixed = False


def _fix_post_backward_hook():
    """Replace FSDP's private `_post_backward_hook` with a wrapper.

    Only the first call has an effect, later calls return immediately.
    The wrapper checks whether the module wrapped by the FSDP state has a
    truthy `_audiocraft_checkpointed` attribute. If so, it forces the training
    state of the FSDP state to `FORWARD_BACKWARD` and the one of the handle
    to `BACKWARD_PRE`. In all cases it then calls the original hook.
    """
    global _hook_fixed
    if _hook_fixed:
        return
    # The flag is raised before the imports below are attempted.
    _hook_fixed = True

    from torch.distributed.fsdp import _runtime_utils
    from torch.distributed.fsdp._common_utils import TrainingState, HandleTrainingState
    original_hook = _runtime_utils._post_backward_hook

    def _post_backward_hook(state, handle, *args, **kwargs):
        wrapped_module = state._fsdp_wrapped_module
        if getattr(wrapped_module, '_audiocraft_checkpointed', False):
            # Reset the states before the original hook runs (presumably because
            # checkpointing runs an extra forward during the backward, to be confirmed).
            state.training_state = TrainingState.FORWARD_BACKWARD
            handle._training_state = HandleTrainingState.BACKWARD_PRE
        original_hook(state, handle, *args, **kwargs)

    _runtime_utils._post_backward_hook = _post_backward_hook
|
|