# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D parallelisms (except pipeline parallelism) and various
# training techniques (e.g. activation checkpointing and compile) to the Llama model.

from collections import defaultdict

import torch
import torch.nn as nn

from torch.distributed._composable.replicate import replicate
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
)

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.tools.logging import logger


def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model should preferably be on the meta device. Otherwise,
    it must fit in GPU or CPU memory.
    """

    if parallel_dims.tp_enabled:
        if (
            job_config.parallelism.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")

        enable_float8_linear = "float8" in job_config.model.converters
        float8_is_rowwise = job_config.float8.recipe_name in (
            "rowwise",
            "rowwise_with_gw_hp",
        )

        # For now, float8 all-gather with TP is only supported for tensorwise
        # float8 scaling recipes. For rowwise recipes, we use regular TP and
        # the all-gather happens in high precision.
        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise

        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
        )

    if job_config.model.use_flex_attn:
        if job_config.activation_checkpoint.mode == "selective":
            raise ValueError(
                "FlexAttention is not compatible with selective AC yet. "
                "See https://github.com/pytorch/pytorch/issues/147879"
            )

        if parallel_dims.cp_enabled:
            raise ValueError(
                "FlexAttention is not compatible with CP yet. "
                "We are still working on this."
            )

    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)

    # turn on per-TransformerBlock compile after AC wrapping and before FSDP
    if job_config.training.compile:
        apply_compile(model)

    if (
        parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
    ):  # apply FSDP or HSDP, potentially with Context Parallel
        if parallel_dims.dp_replicate_enabled:
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP does not support > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
        )

    return model


def apply_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
    loss_parallel: bool,
    enable_float8_tensorwise_tp: bool,
    enable_async_tp: bool,
):
    """Apply tensor parallelism."""
    # 1. Parallelize the embedding and shard its outputs (which are the first
    #    transformer block's inputs)
    # 2. Parallelize the root norm layer over the sequence dim
    # 3. Parallelize the final linear output layer
    parallelize_module(
        model,
        tp_mesh,
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            "norm": SequenceParallel(),
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Shard(-1) if loss_parallel else Replicate(),
                use_local_output=not loss_parallel,
            ),
        },
    )

    # Parallel styles used for transformer block linear weights and their
    # inputs may be different for float8 linears with tensorwise scaling.
    if enable_float8_tensorwise_tp:
        # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
        from torchao.float8.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
            PrepareFloat8ModuleInput,
        )

        rowwise_parallel, colwise_parallel, prepare_module_input = (
            Float8RowwiseParallel,
            Float8ColwiseParallel,
            PrepareFloat8ModuleInput,
        )
    else:
        rowwise_parallel, colwise_parallel, prepare_module_input = (
            RowwiseParallel,
            ColwiseParallel,
            PrepareModuleInput,
        )

    # Apply tensor + sequence parallelism to every transformer block
    # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
    # by folding (and unfolding) the batch dimension and the sequence dimension.
    # Examples can be found at https://github.com/pytorch/torchtitan/pull/437
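    # In the per-block plan below, wq/wk/wv and w1/w3 are column-wise sharded while
    # wo and w2 are row-wise sharded: each block all-gathers its sequence-sharded
    # input once (via PrepareModuleInput) and reduce-scatters the row-wise linear's
    # output back to Shard(1) for the following SequenceParallel norm.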
    for layer_id, transformer_block in model.layers.items():
        layer_plan = {
            "attention_norm": SequenceParallel(),
            "attention": prepare_module_input(
                input_layouts=(Shard(1), None),
                desired_input_layouts=(Replicate(), None),
            ),
            "attention.wq": colwise_parallel(),
            "attention.wk": colwise_parallel(),
            "attention.wv": colwise_parallel(),
            "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
            "ffn_norm": SequenceParallel(),
            "feed_forward": prepare_module_input(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "feed_forward.w1": colwise_parallel(),
            "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
            "feed_forward.w3": colwise_parallel(),
        }

        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=layer_plan,
        )

    if enable_async_tp:
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

        torch._inductor.config._micro_pipeline_tp = True
        enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    logger.info(
        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
        "Tensor Parallelism to the model"
    )


# for selective op activation checkpointing
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    # for low precision training, it's useful to always save
    # the result of max, since the absolute maximum is
    # used to compute the scaling factor for quantization.
    torch.ops.aten.max.default,
}


def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
    valid_ac_modes = ("full", "selective")
    if ac_config.mode not in valid_ac_modes:
        raise ValueError(
            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
        )

    if ac_config.mode == "full":
        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)

    assert ac_config.mode == "selective", f"{ac_config.mode}"
    use_op_sac = ac_config.selective_ac_option == "op"
    use_layer_sac = ac_config.selective_ac_option.isdigit()
    if not use_op_sac and not use_layer_sac:
        raise ValueError(
            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
" f"Valid options: 'op' or a positive int representing layer frequency" ) if use_op_sac: from torch.utils.checkpoint import ( CheckpointPolicy, create_selective_checkpoint_contexts, ) def _get_custom_policy(meta): def _custom_policy(ctx, func, *args, **kwargs): mode = "recompute" if ctx.is_recompute else "forward" mm_count_key = f"{mode}_mm_count" if func == torch.ops.aten.mm.default: meta[mm_count_key] += 1 # Saves output of all compute ops, except every second mm to_save = func in _save_list and not ( func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 ) return ( CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE ) return _custom_policy def selective_checkpointing_context_fn(): meta = defaultdict(int) return create_selective_checkpoint_contexts(_get_custom_policy(meta)) return ptd_checkpoint_wrapper( module, context_fn=selective_checkpointing_context_fn, preserve_rng_state=False, ) elif use_layer_sac: # Checkpoint every `ac_freq` of the modules passed to this function ac_freq = int(ac_config.selective_ac_option) ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0) ptd_checkpoint_wrapper._count += 1 if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0: return ptd_checkpoint_wrapper(module, preserve_rng_state=False) else: return module def apply_ac(model: nn.Module, ac_config): """Apply activation checkpointing to the model.""" for layer_id, transformer_block in model.layers.named_children(): transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config) model.layers.register_module(layer_id, transformer_block) logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") def apply_compile(model: nn.Module): """ Apply torch.compile to each TransformerBlock, which makes compilation efficient due to repeated structure. Alternatively one can compile the whole model (after applying DP). """ for layer_id, transformer_block in model.layers.named_children(): transformer_block = torch.compile(transformer_block, fullgraph=True) model.layers.register_module(layer_id, transformer_block) logger.info("Compiling each TransformerBlock with torch.compile") def apply_fsdp( model: nn.Module, dp_mesh: DeviceMesh, param_dtype: torch.dtype, reduce_dtype: torch.dtype, pp_enabled: bool, cpu_offload: bool = False, reshard_after_forward_policy: str = "default", ): """ Apply data parallelism (via FSDP2) to the model. Args: model (nn.Module): The model to apply data parallelism to. dp_mesh (DeviceMesh): The device mesh to use for data parallelism. param_dtype (torch.dtype): The data type to use for model parameters. reduce_dtype (torch.dtype): The data type to use for reduction operations. pp_enabled (bool): Whether pipeline parallelism is enabled. cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False. reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default". Other options: "never", "always". - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. - "always" will enable `reshard_after_forward` for all forward passes. - "never" will disable `reshard_after_forward` for all forward passes. 
""" mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} if cpu_offload: fsdp_config["offload_policy"] = CPUOffloadPolicy() for layer_id, transformer_block in model.layers.items(): if reshard_after_forward_policy == "always": reshard_after_forward = True elif reshard_after_forward_policy == "never": reshard_after_forward = False elif reshard_after_forward_policy == "default": if pp_enabled: # For PP, do not reshard after forward to avoid per-microbatch # all-gathers, which can be expensive and non-overlapped reshard_after_forward = False else: # As an optimization, do not reshard after forward for the last # transformer block since FSDP would prefetch it immediately reshard_after_forward = int(layer_id) < len(model.layers) - 1 else: raise ValueError( f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." ) fully_shard( transformer_block, **fsdp_config, reshard_after_forward=reshard_after_forward, ) fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled) def apply_ddp( model: nn.Module, dp_mesh: DeviceMesh, enable_compile: bool, enable_compiled_autograd: bool, ): if enable_compile: if enable_compiled_autograd: torch._dynamo.config.optimize_ddp = ( "python_reducer_without_compiled_forward" ) else: torch._dynamo.config.optimize_ddp = "ddp_optimizer" replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) logger.info("Applied DDP to the model")