from collections import defaultdict

import torch
import torch.nn as nn
from torch.distributed._composable.replicate import replicate
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
)

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.tools.logging import logger


def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """

    if parallel_dims.tp_enabled:
        if (
            job_config.parallelism.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")

        enable_float8_linear = "float8" in job_config.model.converters
        float8_is_rowwise = job_config.float8.recipe_name in (
            "rowwise",
            "rowwise_with_gw_hp",
        )
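
        # Float8 TP parallel styles are only used with tensorwise scaling;
        # rowwise float8 recipes fall back to the regular TP styles.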
        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise

        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
        )

    if job_config.model.use_flex_attn:
        if job_config.activation_checkpoint.mode == "selective":
            raise ValueError(
                "FlexAttention is not compatible with selective AC yet. "
                "See https://github.com/pytorch/pytorch/issues/147879"
            )

        if parallel_dims.cp_enabled:
            raise ValueError(
                "FlexAttention is not compatible with CP yet. "
                "We are still working on this."
            )

    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)
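
    # Compile each TransformerBlock after AC wrapping and before applying FSDP.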
    if job_config.training.compile:
        apply_compile(model)
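
    # Apply FSDP or HSDP, potentially combined with Context Parallel.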
    if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled:
        if parallel_dims.dp_replicate_enabled:
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP does not support > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
        )

    return model


def apply_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
    loss_parallel: bool,
    enable_float8_tensorwise_tp: bool,
    enable_async_tp: bool,
):
    """Apply tensor parallelism."""
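
    # Parallelize the embedding, the final norm, and the output projection:
    # the embedding is row-parallel and produces sequence-sharded activations,
    # the final norm runs sequence-parallel, and the output projection is
    # column-parallel (keeping the loss sharded when loss parallelism is on).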
    parallelize_module(
        model,
        tp_mesh,
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            "norm": SequenceParallel(),
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Shard(-1) if loss_parallel else Replicate(),
                use_local_output=not loss_parallel,
            ),
        },
    )
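
    # The parallel styles for the transformer block linears and their input
    # preparation differ when float8 linears with tensorwise scaling are used.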
    if enable_float8_tensorwise_tp:
        from torchao.float8.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
            PrepareFloat8ModuleInput,
        )

        rowwise_parallel, colwise_parallel, prepare_module_input = (
            Float8RowwiseParallel,
            Float8ColwiseParallel,
            PrepareFloat8ModuleInput,
        )
    else:
        rowwise_parallel, colwise_parallel, prepare_module_input = (
            RowwiseParallel,
            ColwiseParallel,
            PrepareModuleInput,
        )
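
    # Apply tensor + sequence parallelism to every transformer block.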
    for layer_id, transformer_block in model.layers.items():
        layer_plan = {
            "attention_norm": SequenceParallel(),
            "attention": prepare_module_input(
                input_layouts=(Shard(1), None),
                desired_input_layouts=(Replicate(), None),
            ),
            "attention.wq": colwise_parallel(),
            "attention.wk": colwise_parallel(),
            "attention.wv": colwise_parallel(),
            "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
            "ffn_norm": SequenceParallel(),
            "feed_forward": prepare_module_input(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "feed_forward.w1": colwise_parallel(),
            "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
            "feed_forward.w3": colwise_parallel(),
        }

        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=layer_plan,
        )
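
    # Async TP is implemented via Inductor's micro-pipeline TP pass and relies
    # on symmetric memory being enabled for the TP process group.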
    if enable_async_tp:
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

        torch._inductor.config._micro_pipeline_tp = True
        enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    logger.info(
        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
        "Tensor Parallelism to the model"
    )
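

# Ops whose outputs are saved (not recomputed) under selective op-level AC.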
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
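    # Saving max outputs is cheap; in low-precision training the absolute
    # maximum typically feeds the quantization scaling factor.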
    torch.ops.aten.max.default,
}


def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
    valid_ac_modes = ("full", "selective")
    if ac_config.mode not in valid_ac_modes:
        raise ValueError(
            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
        )

    if ac_config.mode == "full":
        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)

    assert ac_config.mode == "selective", f"{ac_config.mode}"
    use_op_sac = ac_config.selective_ac_option == "op"
    use_layer_sac = ac_config.selective_ac_option.isdigit()
    if not use_op_sac and not use_layer_sac:
        raise ValueError(
            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
            f"Valid options: 'op' or a positive int representing layer frequency"
        )
    if use_op_sac:
        from torch.utils.checkpoint import (
            CheckpointPolicy,
            create_selective_checkpoint_contexts,
        )

        def _get_custom_policy(meta):
            def _custom_policy(ctx, func, *args, **kwargs):
                mode = "recompute" if ctx.is_recompute else "forward"
                mm_count_key = f"{mode}_mm_count"
                if func == torch.ops.aten.mm.default:
                    meta[mm_count_key] += 1
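
                # Save outputs of ops in _save_list, except every second mm:
                # recomputing every other matmul trades compute for memory.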
                to_save = func in _save_list and not (
                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
                )
                return (
                    CheckpointPolicy.MUST_SAVE
                    if to_save
                    else CheckpointPolicy.PREFER_RECOMPUTE
                )

            return _custom_policy

        def selective_checkpointing_context_fn():
            meta = defaultdict(int)
            return create_selective_checkpoint_contexts(_get_custom_policy(meta))

        return ptd_checkpoint_wrapper(
            module,
            context_fn=selective_checkpointing_context_fn,
            preserve_rng_state=False,
        )
    elif use_layer_sac:
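        # Wrap every `ac_freq`-th transformer block passed to this function,
        # tracking the call count on the wrapper function itself.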
        ac_freq = int(ac_config.selective_ac_option)
        ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
        ptd_checkpoint_wrapper._count += 1
        if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
            return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
        else:
            return module


def apply_ac(model: nn.Module, ac_config):
    """Apply activation checkpointing to the model."""
    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config)
        model.layers.register_module(layer_id, transformer_block)

    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")


def apply_compile(model: nn.Module):
    """
    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
    repeated structure. Alternatively one can compile the whole model (after applying DP).
    """
    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = torch.compile(transformer_block, fullgraph=True)
        model.layers.register_module(layer_id, transformer_block)

    logger.info("Compiling each TransformerBlock with torch.compile")


def apply_fsdp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    pp_enabled: bool,
    cpu_offload: bool = False,
    reshard_after_forward_policy: str = "default",
):
    """
    Apply data parallelism (via FSDP2) to the model.

    Args:
        model (nn.Module): The model to apply data parallelism to.
        dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
        param_dtype (torch.dtype): The data type to use for model parameters.
        reduce_dtype (torch.dtype): The data type to use for reduction operations.
        pp_enabled (bool): Whether pipeline parallelism is enabled.
        cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
        reshard_after_forward_policy (str, optional): The policy to use for resharding after the forward pass.
            Defaults to "default". Other options: "never", "always".
            - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
            - "always" will enable `reshard_after_forward` for all forward passes.
            - "never" will disable `reshard_after_forward` for all forward passes.
    """
    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
    if cpu_offload:
        fsdp_config["offload_policy"] = CPUOffloadPolicy()

    for layer_id, transformer_block in model.layers.items():
        if reshard_after_forward_policy == "always":
            reshard_after_forward = True
        elif reshard_after_forward_policy == "never":
            reshard_after_forward = False
        elif reshard_after_forward_policy == "default":
            if pp_enabled:
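                # With pipeline parallelism, keep parameters gathered after
                # forward so each microbatch does not re-trigger all-gathers.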
                reshard_after_forward = False
            else:
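                # Reshard all but the last transformer block after forward;
                # the last block's parameters are needed again as soon as
                # backward starts.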
                reshard_after_forward = int(layer_id) < len(model.layers) - 1
        else:
            raise ValueError(
                f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
            )
        fully_shard(
            transformer_block,
            **fsdp_config,
            reshard_after_forward=reshard_after_forward,
        )
    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)


def apply_ddp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    enable_compile: bool,
    enable_compiled_autograd: bool,
):
    if enable_compile:
        if enable_compiled_autograd:
            torch._dynamo.config.optimize_ddp = (
                "python_reducer_without_compiled_forward"
            )
        else:
            torch._dynamo.config.optimize_ddp = "ddp_optimizer"

    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

    logger.info("Applied DDP to the model")