# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D pipeline parallelism to the Llama model.

import copy

import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import (
    _PipelineSchedule,
    get_schedule_class,
    ScheduleZBVZeroBubble,
)

from torchtitan.components.loss import LossFunction
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.pipeline import (
    build_pipeline_schedule,
    generate_split_points,
    stage_ids_this_rank,
)
from torchtitan.protocols.train_spec import DeviceType, ParallelizeFunction
from torchtitan.tools.logging import logger

from .model import TransformerModelArgs


def pipeline_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: TransformerModelArgs,
    parallelize_fn: ParallelizeFunction,
    loss_fn: LossFunction,
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
    pp_mesh = world_mesh["pp"]

    stages, model_parts = pipeline_llama_manual_split(
        model, pp_mesh, parallel_dims, job_config, device, model_config
    )

    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
    # optimizer, and checkpointing
    for i, m in enumerate(model_parts):
        # apply SPMD-style PT-D techniques
        m = parallelize_fn(m, world_mesh, parallel_dims, job_config)
        model_parts[i] = m
        # NOTE: this is to update the model in the stage
        # in case the model is modified e.g. by torch.compile
        stages[i].submod = m
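    # Illustrative note (assumed example, not from this file): with a looped schedule,
    # pp degree 4, and 8 total stages, this rank would own two stages (e.g. rank 0
    # owns stages 0 and 4), so model_parts holds two model chunks and each chunk is
    # parallelized independently by the loop above.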
    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    # This is used in the train loop to determine whether to pass in the input_ids and labels
    has_first_stage = False
    has_last_stage = False
    for stage in stages:
        if stage.is_first:
            has_first_stage = True
        if stage.is_last:
            has_last_stage = True
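    # Sketch of intended use (assumption, not code from this file): the two flags let
    # the training loop decide what to feed the schedule, roughly:
    #   targets, losses = (labels, []) if has_last_stage else (None, None)
    #   if has_first_stage:
    #       pp_schedule.step(inputs, target=targets, losses=losses)
    #   else:
    #       pp_schedule.step(target=targets, losses=losses)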
    return pp_schedule, model_parts, has_first_stage, has_last_stage


def pipeline_llama_manual_split(
    whole_model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: TransformerModelArgs,
) -> tuple[list[PipelineStage], list[nn.Module]]:
"""
This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage.
It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects.
The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD
parallelism.
"""
    pp_rank = pp_mesh.get_local_rank()
    pp_size = pp_mesh.size()
    parallelism_config = job_config.parallelism

    splits = parallelism_config.pipeline_parallel_split_points or generate_split_points(
        parallelism_config.pipeline_parallel_schedule,
        parallelism_config.pipeline_parallel_layers_per_stage,
        parallel_dims.pp,
        model_config.n_layers,
    )
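    # e.g. (illustrative values): for a 32-layer model with pipeline degree 4 and no
    # user-provided split points, generate_split_points is expected to return something
    # like ["layers.8", "layers.16", "layers.24"], one fully-qualified layer name per
    # stage boundary; explicit pipeline_parallel_split_points take precedence via `or`.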
    def _build_stage(
        stage_idx: int,
        start_layer: str | None,
        stop_layer: str | None,
        is_first: bool = False,
        is_last: bool = False,
    ) -> tuple[PipelineStage, nn.Module]:
        model = copy.deepcopy(whole_model)
        if not is_first:
            model.tok_embeddings = None

        drop_layers = start_layer is not None
        for name in list(model.layers.keys()):
            # we keep layers in a contiguous region between start (inclusive) and stop (exclusive)
            if f"layers.{name}" == start_layer:
                drop_layers = False
            if f"layers.{name}" == stop_layer:
                drop_layers = True
            if drop_layers:
                del model.layers[name]

        if not is_last:
            model.norm = None
            model.output = None

        stage = PipelineStage(
            model,
            stage_idx,
            num_stages,
            device,
            group=pp_mesh.get_group("pp"),
        )
        return stage, model
    num_stages = len(splits) + 1
    stage_idx = pp_rank

    stages = []
    models = []

    schedule_class = get_schedule_class(parallelism_config.pipeline_parallel_schedule)
    style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"
    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
        start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
        stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
        stage, model_chunk = _build_stage(
            stage_idx,
            start_layer,
            stop_layer,
            is_first=stage_idx == 0,
            is_last=stage_idx == num_stages - 1,
        )
        logger.info(
            f"PP rank {pp_rank} is building stage_idx {stage_idx}"
            f" with start_layer {start_layer}, stop_layer {stop_layer}"
        )
        stages.append(stage)
        models.append(model_chunk)

    return stages, models