# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# This file applies the PT-D parallelisms (except pipeline parallelism) and various
# training techniques (e.g. activation checkpointing and compile) to the Llama model.
from collections import defaultdict
import torch
import torch.nn as nn
from torch.distributed._composable.replicate import replicate
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
)
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
PrepareModuleInput,
RowwiseParallel,
SequenceParallel,
)
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.tools.logging import logger
def parallelize_llama(
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: ParallelDims,
job_config: JobConfig,
):
"""
Apply tensor parallelism, activation checkpointing, torch.compile, and data
parallelism to the model.
    NOTE: The passed-in model should preferably be on the meta device. Otherwise,
    the model must fit in GPU or CPU memory.
"""
if parallel_dims.tp_enabled:
if (
job_config.parallelism.enable_async_tensor_parallel
and not job_config.training.compile
):
raise RuntimeError("Async TP requires --training.compile")
enable_float8_linear = "float8" in job_config.model.converters
float8_is_rowwise = job_config.float8.recipe_name in (
"rowwise",
"rowwise_with_gw_hp",
)
# For now, float8 all-gather with TP is only supported for tensorwise
# float8 scaling recipes. For rowwise recipes, we use regular TP and
# all-gather happens in high precision.
enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
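        # Illustrative (assumed) config snippet selecting a float8 recipe:
        #   [model]
        #   converters = ["float8"]
        #   [float8]
        #   recipe_name = "rowwise"  # rowwise recipes fall back to regular TP all-gather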
apply_tp(
model,
world_mesh["tp"],
loss_parallel=parallel_dims.loss_parallel_enabled,
enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
)
if job_config.model.use_flex_attn:
if job_config.activation_checkpoint.mode == "selective":
raise ValueError(
"FlexAttention is not compatible with selective AC yet. "
"See https://github.com/pytorch/pytorch/issues/147879"
)
if parallel_dims.cp_enabled:
raise ValueError(
"FlexAttention is not compatible with CP yet. "
"We are still working on this."
)
if job_config.activation_checkpoint.mode != "none":
apply_ac(model, job_config.activation_checkpoint)
# turn on per-TransformerBlock compile after AC wrapping and before FSDP
if job_config.training.compile:
apply_compile(model)
if (
parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
): # apply FSDP or HSDP, potentially with Context Parallel
if parallel_dims.dp_replicate_enabled:
dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
else:
dp_mesh_dim_names = ("dp_shard_cp",)
apply_fsdp(
model,
world_mesh[tuple(dp_mesh_dim_names)],
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
pp_enabled=parallel_dims.pp_enabled,
cpu_offload=job_config.training.enable_cpu_offload,
reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
)
if parallel_dims.dp_replicate_enabled:
logger.info("Applied HSDP to the model")
else:
logger.info("Applied FSDP to the model")
if parallel_dims.cp_enabled:
logger.info("Applied Context Parallel to the model")
if job_config.training.enable_cpu_offload:
logger.info("Applied CPU Offloading to the model")
elif parallel_dims.dp_replicate_enabled:
if world_mesh.ndim > 1:
raise RuntimeError("DDP has not supported > 1D parallelism")
apply_ddp(
model,
world_mesh,
enable_compile=job_config.training.compile,
enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
)
return model
def apply_tp(
model: nn.Module,
tp_mesh: DeviceMesh,
loss_parallel: bool,
enable_float8_tensorwise_tp: bool,
enable_async_tp: bool,
):
"""Apply tensor parallelism."""
# 1. Parallelize the embedding and shard its outputs (which are the first
# transformer block's inputs)
# 2. Parallelize the root norm layer over the sequence dim
# 3. Parallelize the final linear output layer
parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"norm": SequenceParallel(),
"output": ColwiseParallel(
input_layouts=Shard(1),
output_layouts=Shard(-1) if loss_parallel else Replicate(),
use_local_output=not loss_parallel,
),
},
)
# Parallel styles used for transformer block linear weights and their
# inputs may be different for float8 linears with tensorwise scaling.
if enable_float8_tensorwise_tp:
# TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
from torchao.float8.float8_tensor_parallel import (
Float8ColwiseParallel,
Float8RowwiseParallel,
PrepareFloat8ModuleInput,
)
rowwise_parallel, colwise_parallel, prepare_module_input = (
Float8RowwiseParallel,
Float8ColwiseParallel,
PrepareFloat8ModuleInput,
)
else:
rowwise_parallel, colwise_parallel, prepare_module_input = (
RowwiseParallel,
ColwiseParallel,
PrepareModuleInput,
)
# Apply tensor + sequence parallelism to every transformer block
    # NOTE: At the cost of model code changes, we can accelerate Sequence Parallel
# by folding (and unfolding) the batch dimension and the sequence dimension.
# Examples can be found at https://github.com/pytorch/torchtitan/pull/437
for layer_id, transformer_block in model.layers.items():
layer_plan = {
"attention_norm": SequenceParallel(),
"attention": prepare_module_input(
input_layouts=(Shard(1), None),
desired_input_layouts=(Replicate(), None),
),
"attention.wq": colwise_parallel(),
"attention.wk": colwise_parallel(),
"attention.wv": colwise_parallel(),
"attention.wo": rowwise_parallel(output_layouts=Shard(1)),
"ffn_norm": SequenceParallel(),
"feed_forward": prepare_module_input(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": colwise_parallel(),
"feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
"feed_forward.w3": colwise_parallel(),
}
parallelize_module(
module=transformer_block,
device_mesh=tp_mesh,
parallelize_plan=layer_plan,
)
if enable_async_tp:
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
torch._inductor.config._micro_pipeline_tp = True
enable_symm_mem_for_group(tp_mesh.get_group().group_name)
logger.info(
f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
"Tensor Parallelism to the model"
)
# for selective op activation checkpointing
_save_list = {
torch.ops.aten.mm.default,
torch.ops.aten._scaled_dot_product_efficient_attention.default,
torch.ops.aten._scaled_dot_product_flash_attention.default,
# for low precision training, it's useful to always save
# the result of max, since the absolute maximum is
# used to compute the scaling factor for quantization.
torch.ops.aten.max.default,
}
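# _save_list is consumed by the op-level selective AC policy built in
# `_apply_ac_to_transformer_block` below: outputs of these ops are saved during
# the forward pass (except every second mm, which is recomputed), while all
# other activations are recomputed in the backward pass.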
def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
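    """
    Wrap a single TransformerBlock with activation checkpointing according to ac_config.

    "full" mode checkpoints the entire block; "selective" mode either applies the
    op-level policy built from `_save_list` above (option "op") or checkpoints one
    out of every `ac_freq` blocks (option set to a positive int).
    """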
valid_ac_modes = ("full", "selective")
if ac_config.mode not in valid_ac_modes:
raise ValueError(
f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
)
if ac_config.mode == "full":
return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
assert ac_config.mode == "selective", f"{ac_config.mode}"
use_op_sac = ac_config.selective_ac_option == "op"
use_layer_sac = ac_config.selective_ac_option.isdigit()
if not use_op_sac and not use_layer_sac:
raise ValueError(
f"Invalid selective AC option: {ac_config.selective_ac_option}. "
f"Valid options: 'op' or a positive int representing layer frequency"
)
if use_op_sac:
from torch.utils.checkpoint import (
CheckpointPolicy,
create_selective_checkpoint_contexts,
)
def _get_custom_policy(meta):
def _custom_policy(ctx, func, *args, **kwargs):
mode = "recompute" if ctx.is_recompute else "forward"
mm_count_key = f"{mode}_mm_count"
if func == torch.ops.aten.mm.default:
meta[mm_count_key] += 1
                # Save the outputs of ops in _save_list, except every second mm
to_save = func in _save_list and not (
func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
)
return (
CheckpointPolicy.MUST_SAVE
if to_save
else CheckpointPolicy.PREFER_RECOMPUTE
)
return _custom_policy
def selective_checkpointing_context_fn():
meta = defaultdict(int)
return create_selective_checkpoint_contexts(_get_custom_policy(meta))
return ptd_checkpoint_wrapper(
module,
context_fn=selective_checkpointing_context_fn,
preserve_rng_state=False,
)
elif use_layer_sac:
        # Checkpoint one out of every `ac_freq` modules passed to this function
ac_freq = int(ac_config.selective_ac_option)
ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
ptd_checkpoint_wrapper._count += 1
if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
else:
return module
def apply_ac(model: nn.Module, ac_config):
"""Apply activation checkpointing to the model."""
for layer_id, transformer_block in model.layers.named_children():
transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config)
model.layers.register_module(layer_id, transformer_block)
logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
def apply_compile(model: nn.Module):
"""
    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
    the repeated block structure. Alternatively, one can compile the whole model (after applying DP).
"""
for layer_id, transformer_block in model.layers.named_children():
transformer_block = torch.compile(transformer_block, fullgraph=True)
model.layers.register_module(layer_id, transformer_block)
logger.info("Compiling each TransformerBlock with torch.compile")
def apply_fsdp(
model: nn.Module,
dp_mesh: DeviceMesh,
param_dtype: torch.dtype,
reduce_dtype: torch.dtype,
pp_enabled: bool,
cpu_offload: bool = False,
reshard_after_forward_policy: str = "default",
):
"""
Apply data parallelism (via FSDP2) to the model.
Args:
model (nn.Module): The model to apply data parallelism to.
dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
param_dtype (torch.dtype): The data type to use for model parameters.
reduce_dtype (torch.dtype): The data type to use for reduction operations.
pp_enabled (bool): Whether pipeline parallelism is enabled.
cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
        reshard_after_forward_policy (str, optional): The policy to use for resharding after the forward pass. Defaults to "default".
Other options: "never", "always".
- "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
- "always" will enable `reshard_after_forward` for all forward passes.
- "never" will disable `reshard_after_forward` for all forward passes.
"""
mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
if cpu_offload:
fsdp_config["offload_policy"] = CPUOffloadPolicy()
for layer_id, transformer_block in model.layers.items():
if reshard_after_forward_policy == "always":
reshard_after_forward = True
elif reshard_after_forward_policy == "never":
reshard_after_forward = False
elif reshard_after_forward_policy == "default":
if pp_enabled:
# For PP, do not reshard after forward to avoid per-microbatch
# all-gathers, which can be expensive and non-overlapped
reshard_after_forward = False
else:
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = int(layer_id) < len(model.layers) - 1
else:
raise ValueError(
f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
)
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
def apply_ddp(
model: nn.Module,
dp_mesh: DeviceMesh,
enable_compile: bool,
enable_compiled_autograd: bool,
):
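    """Apply data parallelism (DDP) to the model via the composable `replicate` API."""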
if enable_compile:
if enable_compiled_autograd:
torch._dynamo.config.optimize_ddp = (
"python_reducer_without_compiled_forward"
)
else:
torch._dynamo.config.optimize_ddp = "ddp_optimizer"
replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
logger.info("Applied DDP to the model")