|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections import defaultdict

import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
from torch.distributed._composable.replicate import replicate
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    PrepareModuleOutput,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module
)

from fla.modules.fused_linear_cross_entropy import LinearLossParallel
from fla.modules.mlp import SwiGLULinearParallel
from fla.modules.parallel import PrepareModuleWeight
from torchtitan.config_manager import TORCH_DTYPE_MAP, JobConfig
from torchtitan.distributed.parallel_dims import ParallelDims
from torchtitan.tools.logging import logger
|
|
|
|
|
def parallelize_fla(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: the model passed in preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """

    if parallel_dims.tp_enabled:
        if (
            job_config.experimental.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")
        enable_float8_linear = "float8" in job_config.model.converters
        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8=enable_float8_linear,
            enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
        )

    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)

    # turn on per-block compile after AC wrapping and before applying FSDP
    if job_config.training.compile:
        apply_compile(model)

    if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled:
        if parallel_dims.dp_replicate_enabled:
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.training.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP does not support > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
        )
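
# Usage sketch (hypothetical wiring; in practice the torchtitan trainer builds
# the mesh/config objects and calls this entry point):
#
#     model = build_model_on_meta_device(model_config)   # assumed helper
#     parallelize_fla(model, world_mesh, parallel_dims, job_config)
#     model.to_empty(device="cuda")                      # materialize sharded params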
|
|
|
|
|
class TPPlan:
    def __init__(
        self,
        model=None,
        loss_parallel=False,
        enable_float8=False,
    ):
        self.model = model
        self.loss_parallel = loss_parallel
        self.enable_float8 = enable_float8
        self.base_model_prefix = getattr(model, "base_model_prefix", "model")

        # Use float8-aware TP styles when requested and torchao is available;
        # otherwise fall back to the stock parallel styles.
        try:
            from torchao.float8.float8_tensor_parallel import (
                Float8ColwiseParallel,
                Float8RowwiseParallel,
                PrepareFloat8ModuleInput
            )
        except ImportError:
            Float8ColwiseParallel = None
            Float8RowwiseParallel = None
            PrepareFloat8ModuleInput = None
        if self.enable_float8 and Float8ColwiseParallel is not None:
            self.rowwise_parallel = Float8RowwiseParallel
            self.colwise_parallel = Float8ColwiseParallel
            self.prepare_module_input = PrepareFloat8ModuleInput
            self.prepare_module_output = PrepareModuleOutput
        else:
            self.rowwise_parallel = RowwiseParallel
            self.colwise_parallel = ColwiseParallel
            self.prepare_module_input = PrepareModuleInput
            self.prepare_module_output = PrepareModuleOutput
|
|
|
    @property
    def model_plan(self):
        plans = {
            f"{self.base_model_prefix}.embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            f"{self.base_model_prefix}.norm": SequenceParallel(),
        }
        if self.loss_parallel:
            # keep lm_head outputs sharded on the vocab dim for loss parallelism
            plans.update(
                {
                    "lm_head": ColwiseParallel(
                        input_layouts=Shard(1),
                        output_layouts=Shard(-1),
                        use_local_output=False,
                    ),
                }
            )
        else:
            plans.update(
                {
                    "lm_head": PrepareModuleWeight(layouts=Replicate()),
                    "criterion": LinearLossParallel(),
                }
            )
        return plans
|
|
|
    @property
    def layer_plan(self):
        return {
            "attn_norm": SequenceParallel(),
            **self.attn_plan,
            "mlp_norm": SequenceParallel(),
            **self.mlp_plan,
        }

    @property
    def attn_plan(self):
        raise NotImplementedError(
            f"TP plans for token mixing layers of {self.model.config.model_type} are not implemented"
        )

    @property
    def mlp_plan(self):
        return {
            "mlp": self.prepare_module_input(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "mlp.gate_proj": self.colwise_parallel(),
            "mlp.up_proj": self.colwise_parallel(),
            "mlp.down_proj": self.rowwise_parallel(output_layouts=Shard(1)),
            "mlp.swiglu_linear": SwiGLULinearParallel(output_layouts=Shard(1)),
        }
|
|
|
|
|
class TransformerTPPlan(TPPlan):

    @property
    def attn_plan(self):
        return {
            "attn": self.prepare_module_input(
                input_kwarg_layouts={"hidden_states": Shard(1)},
                desired_input_kwarg_layouts={"hidden_states": Replicate()},
            ),
            "attn.q_proj": self.colwise_parallel(),
            "attn.k_proj": self.colwise_parallel(),
            "attn.v_proj": self.colwise_parallel(),
            "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
        }
|
|
|
|
|
class GLATPPlan(TPPlan):

    @property
    def attn_plan(self):
        return {
            "attn": self.prepare_module_input(
                input_kwarg_layouts={"hidden_states": Shard(1)},
                desired_input_kwarg_layouts={"hidden_states": Replicate()},
            ),
            "attn.q_proj": self.colwise_parallel(),
            "attn.k_proj": self.colwise_parallel(),
            "attn.v_proj": self.colwise_parallel(),
            "attn.g_proj": self.colwise_parallel(),
            "attn.gk_proj.0": PrepareModuleWeight(layouts=Replicate()),
            "attn.gk_proj.1": self.colwise_parallel(),
            "attn.g_norm": SequenceParallel(sequence_dim=-1),
            "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
        }
|
|
|
|
|
TP_PLAN_MAP = {"transformer": TransformerTPPlan, "gla": GLATPPlan} |
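
# Sketch: supporting a new token-mixing architecture only needs a TPPlan
# subclass plus a map entry. The class and key below are hypothetical,
# shown for illustration only:
#
#     class DeltaNetTPPlan(TPPlan):
#         @property
#         def attn_plan(self):
#             return {
#                 "attn.q_proj": self.colwise_parallel(),
#                 "attn.k_proj": self.colwise_parallel(),
#                 "attn.v_proj": self.colwise_parallel(),
#                 "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
#             }
#
#     TP_PLAN_MAP["delta_net"] = DeltaNetTPPlan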
|
|
|
|
|
def apply_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
    loss_parallel: bool,
    enable_float8: bool,
    enable_async_tp: bool,
):
    """Apply tensor parallelism."""

    tp_plan = TP_PLAN_MAP[model.config.model_type](
        model, loss_parallel=loss_parallel, enable_float8=enable_float8
    )
    parallelize_module(model, tp_mesh, tp_plan.model_plan)

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for tensor parallelism")
    else:
        for block in blocks:
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=tp_plan.layer_plan,
            )

    if enable_async_tp:
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

        torch._inductor.config._micro_pipeline_tp = True
        enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    logger.info(
        f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
        "Tensor Parallelism to the model"
    )
|
|
|
|
|
|
|
# Ops whose outputs are worth saving (rather than recomputing) under
# selective activation checkpointing.
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
    torch.ops.aten.max.default,
}
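
# Worked example of the op-level policy in `_apply_ac_to_block` below: matmul
# outputs alternate, so the 1st, 3rd, 5th, ... mm in a block are saved while
# the 2nd, 4th, ... are recomputed in backward, trading recompute for memory.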
|
|
|
|
|
def _apply_ac_to_block(module: nn.Module, ac_config):
    valid_ac_modes = ("full", "selective")
    if ac_config.mode not in valid_ac_modes:
        raise ValueError(
            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
        )

    if ac_config.mode == "full":
        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)

    assert ac_config.mode == "selective", f"{ac_config.mode}"
    use_op_sac = ac_config.selective_ac_option == "op"
    use_layer_sac = ac_config.selective_ac_option.isdigit()
    if not use_op_sac and not use_layer_sac:
        raise ValueError(
            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
            f"Valid options: 'op' or a positive int representing layer frequency"
        )
    if use_op_sac:
        from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts

        def _get_custom_policy(meta):
            def _custom_policy(ctx, func, *args, **kwargs):
                mode = "recompute" if ctx.is_recompute else "forward"
                mm_count_key = f"{mode}_mm_count"
                if func == torch.ops.aten.mm.default:
                    meta[mm_count_key] += 1

                # Save the outputs of ops in `_save_list`, except every second mm
                to_save = func in _save_list and not (
                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
                )
                return (
                    CheckpointPolicy.MUST_SAVE
                    if to_save
                    else CheckpointPolicy.PREFER_RECOMPUTE
                )

            return _custom_policy

        def selective_checkpointing_context_fn():
            meta = defaultdict(int)
            return create_selective_checkpoint_contexts(_get_custom_policy(meta))

        return ptd_checkpoint_wrapper(
            module,
            context_fn=selective_checkpointing_context_fn,
            preserve_rng_state=False,
        )
    elif use_layer_sac:
        # Checkpoint every `ac_freq` of the blocks passed to this function
        ac_freq = int(ac_config.selective_ac_option)
        ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
        ptd_checkpoint_wrapper._count += 1
        if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
            return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
        else:
            return module
|
|
|
|
|
def apply_ac(model: nn.Module, ac_config):
    """Apply activation checkpointing to the model."""
    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for activation checkpointing")
        return

    for layer_id, block in blocks.named_children():
        block = _apply_ac_to_block(block, ac_config)
        blocks.register_module(layer_id, block)

    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
|
|
|
|
|
def apply_compile(model: nn.Module):
    """
    Apply torch.compile to each block, which makes compilation efficient due to
    repeated structure. Alternatively one can compile the whole model (after applying DP).
    """

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for torch.compile")
    else:
        for layer_id, block in blocks.named_children():
            block = torch.compile(block)
            blocks.register_module(layer_id, block)
        logger.info("Compiling each block with torch.compile")

    real_model = get_model(model)

    logger.info("Compiling the embedding, norm, and lm_head layers with torch.compile")
    embeddings_key = get_components_name(real_model, "tok_embeddings")
    if embeddings_key is not None:
        embeddings = torch.compile(getattr(real_model, embeddings_key), fullgraph=True)
        real_model.register_module(embeddings_key, embeddings)

    norm_key = get_components_name(real_model, "norm")
    if norm_key is not None:
        norm = torch.compile(getattr(real_model, norm_key), fullgraph=True)
        real_model.register_module(norm_key, norm)

    lm_head_key = get_components_name(model, "lm_head")
    if lm_head_key is not None:
        lm_head = torch.compile(getattr(model, lm_head_key), fullgraph=True)
        model.register_module(lm_head_key, lm_head)

    logger.info("Compiling the entire model with torch.compile")
    # `nn.Module.compile()` compiles in place; plain `model = torch.compile(model)`
    # would only rebind the local name and the wrapper would be lost on return.
    model.compile()
|
|
|
|
|
def apply_fsdp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    pp_enabled: bool,
    cpu_offload: bool = False,
    reshard_after_forward_policy: str = "default",
):
    """
    Apply data parallelism (via FSDP2) to the model.

    Args:
        model (nn.Module): The model to apply data parallelism to.
        dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
        param_dtype (torch.dtype): The data type to use for model parameters.
        reduce_dtype (torch.dtype): The data type to use for reduction operations.
        pp_enabled (bool): Whether pipeline parallelism is enabled.
        cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
        reshard_after_forward_policy (str, optional):
            The policy to use for resharding after the forward pass. Defaults to "default".
            Other options: "never", "always".
            - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
            - "always" will enable `reshard_after_forward` for all forward passes.
            - "never" will disable `reshard_after_forward` for all forward passes.
    """
    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
    if cpu_offload:
        fsdp_config["offload_policy"] = CPUOffloadPolicy()

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for FSDP")
    else:
        total_blocks = len(blocks)
        for layer_id, block in enumerate(blocks):
            if reshard_after_forward_policy == "always":
                reshard_after_forward = True
            elif reshard_after_forward_policy == "never":
                reshard_after_forward = False
            elif reshard_after_forward_policy == "default":
                if pp_enabled:
                    # For PP, do not reshard after forward to avoid per-microbatch
                    # all-gathers, which can be expensive and non-overlapped
                    reshard_after_forward = False
                else:
                    # As an optimization, do not reshard the last block after
                    # forward, since FSDP would prefetch it immediately anyway
                    reshard_after_forward = int(layer_id) < total_blocks - 1
            else:
                raise ValueError(
                    f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
                )
            fully_shard(
                block,
                **fsdp_config,
                reshard_after_forward=reshard_after_forward,
            )

    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
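
# Sketch of the kind of mesh this function expects (dim names follow the caller
# in `parallelize_fla`; the shape is illustrative only):
#
#     from torch.distributed.device_mesh import init_device_mesh
#
#     world_mesh = init_device_mesh(
#         "cuda", (2, 4), mesh_dim_names=("dp_replicate", "dp_shard_cp")
#     )
#     apply_fsdp(
#         model,
#         world_mesh[("dp_replicate", "dp_shard_cp")],  # HSDP: 2-way replicate x 4-way shard
#         param_dtype=torch.bfloat16,
#         reduce_dtype=torch.float32,
#         pp_enabled=False,
#     )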
|
|
|
|
|
def apply_ddp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    enable_compile: bool,
    enable_compiled_autograd: bool,
):
    if enable_compile:
        if enable_compiled_autograd:
            torch._dynamo.config.optimize_ddp = (
                "python_reducer_without_compiled_forward"
            )
        else:
            torch._dynamo.config.optimize_ddp = "ddp_optimizer"

    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

    logger.info("Applied DDP to the model")
|
|
|
|
|
def get_model(model):
    base_model_prefix = getattr(model, "base_model_prefix", "model")
    if not hasattr(model, base_model_prefix):
        return None
    return getattr(model, base_model_prefix)


def get_blocks(model):
    model = get_model(model)
    if model is None or not hasattr(model, "layers"):
        logger.warning('could not find "layers" in the model')
        return None
    return model.layers
|
|
|
|
|
def get_components_name(model, component_name):
    """
    Resolve the attribute names of the token embeddings, the final norm layer,
    and the lm_head. Layer names inside the blocks are not handled here; for
    blocks see `get_blocks`.

    We assume the model has the following structure:

        LlamaForCausalLM:
            model:
                embed_tokens
                layers
                norm
            lm_head

    So to search for 'tok_embeddings' and 'norm', pass `get_model(model)`,
    and for 'lm_head', pass `model` itself.
    """

    if component_name == "tok_embeddings":
        if hasattr(model, "tok_embeddings"):
            return "tok_embeddings"
        elif hasattr(model, "embed_tokens"):
            return "embed_tokens"
        elif hasattr(model, "embeddings"):
            return "embeddings"
        else:
            logger.warning("No tok_embeddings found in model")
            return None

    elif component_name == "norm":
        if hasattr(model, "norm"):
            return "norm"
        elif hasattr(model, "norms"):
            return "norms"
        elif hasattr(model, "layernorm"):
            return "layernorm"
        else:
            logger.warning("No norm found in model")
            return None

    elif component_name == "lm_head":
        if hasattr(model, "lm_head"):
            return "lm_head"
        else:
            logger.warning("No lm_head found in model")
            return None
|
|