from datetime import timedelta
from functools import partial
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
    StateDictType,
)
from torch.distributed.fsdp.api import CPUOffload
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
def fsdp_state_dict(model):
    """Gather a full (unsharded) state dict, offloaded to CPU and populated on rank 0 only."""
    fsdp_fullstate_save_policy = FullStateDictConfig(
        offload_to_cpu=True, rank0_only=True
    )
    with FSDP.state_dict_type(
        model, StateDictType.FULL_STATE_DICT, fsdp_fullstate_save_policy
    ):
        checkpoint = model.state_dict()

    return checkpoint
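
# Usage sketch for fsdp_state_dict, assuming torch.distributed is initialized and
# `model` is already FSDP-wrapped; the default path is illustrative only.
def example_save_full_checkpoint(model, path="checkpoint.pt"):
    checkpoint = fsdp_state_dict(model)  # full state dict, on CPU, only meaningful on rank 0
    if dist.get_rank() == 0:
        torch.save(checkpoint, path)
    barrier()  # keep other ranks from racing ahead while rank 0 writes the file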
def fsdp_wrap(
    module,
    sharding_strategy="full",
    mixed_precision=False,
    wrap_strategy="size",
    min_num_params=int(5e7),
    transformer_module=None,
    cpu_offload=False,
):
    """Wrap a module in FSDP with the requested sharding, precision, and auto-wrap policy."""
    if mixed_precision:
        # Keep parameters in bf16, but do gradient reduction and buffers in fp32.
        mixed_precision_policy = MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.float32,
            cast_forward_inputs=False,
        )
    else:
        mixed_precision_policy = None

    if wrap_strategy == "transformer":
        # transformer_module must be a set of transformer block classes.
        auto_wrap_policy = partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=transformer_module,
        )
    elif wrap_strategy == "size":
        auto_wrap_policy = partial(
            size_based_auto_wrap_policy,
            min_num_params=min_num_params,
        )
    else:
        raise ValueError(f"Invalid wrap strategy: {wrap_strategy}")

    # NCCL tuning: allow communication paths to cross NICs.
    os.environ["NCCL_CROSS_NIC"] = "1"
    sharding_strategy = {
        "full": ShardingStrategy.FULL_SHARD,
        "hybrid_full": ShardingStrategy.HYBRID_SHARD,
        "hybrid_zero2": ShardingStrategy._HYBRID_SHARD_ZERO2,
        "no_shard": ShardingStrategy.NO_SHARD,
    }[sharding_strategy]

    module = FSDP(
        module,
        auto_wrap_policy=auto_wrap_policy,
        sharding_strategy=sharding_strategy,
        mixed_precision=mixed_precision_policy,
        device_id=torch.cuda.current_device(),
        limit_all_gathers=True,
        use_orig_params=True,
        cpu_offload=CPUOffload(offload_params=cpu_offload),
        sync_module_states=False,  # Load ckpt on rank 0 and sync to other ranks
    )
    return module
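
# Usage sketch for fsdp_wrap, assuming the process group and CUDA device are already
# set up (e.g. via launch_distributed_job below). torch.nn.TransformerEncoderLayer is
# used purely as an example block class for the "transformer" wrap strategy.
def example_wrap_transformer():
    encoder = torch.nn.TransformerEncoder(
        torch.nn.TransformerEncoderLayer(d_model=512, nhead=8), num_layers=6
    )
    return fsdp_wrap(
        encoder,
        sharding_strategy="full",
        mixed_precision=True,
        wrap_strategy="transformer",
        transformer_module={torch.nn.TransformerEncoderLayer},
    )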
def barrier():
    """Synchronize all ranks if the process group has been initialized."""
    if dist.is_initialized():
        dist.barrier()
def launch_distributed_job(backend: str = "nccl"):
    """Initialize torch.distributed from the standard torchrun environment variables."""
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    host = os.environ["MASTER_ADDR"]
    port = int(os.environ["MASTER_PORT"])

    if ":" in host:  # IPv6
        init_method = f"tcp://[{host}]:{port}"
    else:  # IPv4
        init_method = f"tcp://{host}:{port}"

    dist.init_process_group(
        rank=rank,
        world_size=world_size,
        backend=backend,
        init_method=init_method,
        timeout=timedelta(minutes=30),
    )
    torch.cuda.set_device(local_rank)
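
# Entry-point sketch, assuming the script is started with
#   torchrun --nproc_per_node=<gpus> --nnodes=<nodes> train.py
# so that RANK / LOCAL_RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT are set for us.
# `build_model()` is a hypothetical factory standing in for the caller's model code.
#
#   launch_distributed_job(backend="nccl")
#   model = fsdp_wrap(build_model(), wrap_strategy="size")
#   ...
#   dist.destroy_process_group()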
class EMA_FSDP:
    """Exponential moving average of an FSDP module's parameters, tracked in fp32 on CPU."""

    def __init__(self, fsdp_module: torch.nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = {}
        self._init_shadow(fsdp_module)

    def _init_shadow(self, fsdp_module):
        with FSDP.summon_full_params(fsdp_module, writeback=False):
            for n, p in fsdp_module.module.named_parameters():
                self.shadow[n] = p.detach().clone().float().cpu()

    def update(self, fsdp_module):
        d = self.decay
        with FSDP.summon_full_params(fsdp_module, writeback=False):
            for n, p in fsdp_module.module.named_parameters():
                self.shadow[n].mul_(d).add_(p.detach().float().cpu(), alpha=1. - d)

    # Optional helpers ---------------------------------------------------
    def state_dict(self):
        return self.shadow  # picklable

    def load_state_dict(self, sd):
        self.shadow = {k: v.clone() for k, v in sd.items()}

    def copy_to(self, fsdp_module):
        # Load the EMA weights back into the generator's parameters.
        with FSDP.summon_full_params(fsdp_module, writeback=True):
            for n, p in fsdp_module.module.named_parameters():
                if n in self.shadow:
                    p.data.copy_(self.shadow[n].to(device=p.device, dtype=p.dtype))
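
# Training-loop sketch for EMA_FSDP, assuming `model` is the FSDP-wrapped generator and
# `optimizer` / `dataloader` / `compute_loss` are hypothetical names from the caller's code:
#
#   ema = EMA_FSDP(model, decay=0.999)
#   for batch in dataloader:
#       loss = compute_loss(model, batch)
#       loss.backward()
#       optimizer.step()
#       optimizer.zero_grad()
#       ema.update(model)      # shadow weights tracked in fp32 on CPU
#   ema.copy_to(model)         # write EMA weights back into the model for eval/export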