# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# This file applies the PT-D pipeline parallelism to the Llama model.

import copy

import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import (
_PipelineSchedule,
get_schedule_class,
ScheduleZBVZeroBubble,
)

from torchtitan.components.loss import LossFunction
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.pipeline import (
build_pipeline_schedule,
generate_split_points,
stage_ids_this_rank,
)
from torchtitan.protocols.train_spec import DeviceType, ParallelizeFunction
from torchtitan.tools.logging import logger
from .model import TransformerModelArgs


def pipeline_llama(
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: ParallelDims,
job_config: JobConfig,
device: DeviceType,
model_config: TransformerModelArgs,
parallelize_fn: ParallelizeFunction,
loss_fn: LossFunction,
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
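    """
    Split the Llama model into pipeline stages for this rank, apply SPMD parallelisms to each
    model chunk via ``parallelize_fn``, and build the pipeline schedule.

    Returns the pipeline schedule, the list of model chunks owned by this rank, and two flags
    indicating whether this rank holds the first and/or the last pipeline stage.
    """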
pp_mesh = world_mesh["pp"]
stages, model_parts = pipeline_llama_manual_split(
model, pp_mesh, parallel_dims, job_config, device, model_config
)
# For PP with looped schedules, each item in model_parts is one stage-model-chunk.
# We need to iterate through model_parts to apply SPMD parallelisms, compilation,
# optimizer, and checkpointing
for i, m in enumerate(model_parts):
# apply SPMD-style PT-D techniques
m = parallelize_fn(m, world_mesh, parallel_dims, job_config)
model_parts[i] = m
# NOTE: this is to update the model in the stage
# in case the model is modified e.g. by torch.compile
stages[i].submod = m
pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
# This is used in the train loop to determine whether to pass in the input_ids and labels
has_first_stage = False
has_last_stage = False
for stage in stages:
if stage.is_first:
has_first_stage = True
if stage.is_last:
has_last_stage = True
    return pp_schedule, model_parts, has_first_stage, has_last_stage


def pipeline_llama_manual_split(
whole_model: nn.Module,
pp_mesh: DeviceMesh,
parallel_dims: ParallelDims,
job_config: JobConfig,
device: DeviceType,
model_config: TransformerModelArgs,
) -> tuple[list[PipelineStage], list[nn.Module]]:
"""
This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage.
It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects.
The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD
parallelism.
"""
pp_rank = pp_mesh.get_local_rank()
pp_size = pp_mesh.size()
parallelism_config = job_config.parallelism
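    # Split points are fully qualified layer names (matching the f"layers.{name}" keys used
    # below); e.g. for a 32-layer model, splits == ["layers.8", "layers.16", "layers.24"] yields
    # 4 stages of 8 layers each.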
splits = parallelism_config.pipeline_parallel_split_points or generate_split_points(
parallelism_config.pipeline_parallel_schedule,
parallelism_config.pipeline_parallel_layers_per_stage,
parallel_dims.pp,
model_config.n_layers,
    )

    def _build_stage(
stage_idx: int,
start_layer: str | None,
stop_layer: str | None,
is_first: bool = False,
is_last: bool = False,
) -> tuple[PipelineStage, nn.Module]:
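        """
        Build one PipelineStage by deep-copying the whole model and pruning everything that does
        not belong to this stage: the token embedding on non-first stages, the final norm and
        output projection on non-last stages, and all transformer layers outside the
        [start_layer, stop_layer) range.
        """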
model = copy.deepcopy(whole_model)
if not is_first:
model.tok_embeddings = None
drop_layers = start_layer is not None
for name in list(model.layers.keys()):
# we keep layers in a contiguous region between start (inclusive) and stop (exclusive)
if f"layers.{name}" == start_layer:
drop_layers = False
if f"layers.{name}" == stop_layer:
drop_layers = True
if drop_layers:
del model.layers[name]
if not is_last:
model.norm = None
model.output = None
stage = PipelineStage(
model,
stage_idx,
num_stages,
device,
group=pp_mesh.get_group("pp"),
)
        return stage, model

    num_stages = len(splits) + 1
stage_idx = pp_rank
stages = []
models = []
schedule_class = get_schedule_class(parallelism_config.pipeline_parallel_schedule)
style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"
for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
stage, model_chunk = _build_stage(
stage_idx,
start_layer,
stop_layer,
is_first=stage_idx == 0,
is_last=stage_idx == num_stages - 1,
)
logger.info(
f"PP rank {pp_rank} is building stage_idx {stage_idx}"
f" with start_layer {start_layer}, stop_layer {stop_layer}"
)
stages.append(stage)
models.append(model_chunk)
return stages, models
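

# A minimal sketch of how this entry point is typically wired into a training loop; the variable
# names below (model_args, train_spec, inputs/targets) are illustrative assumptions, not part of
# this file:
#
#     pp_schedule, model_parts, has_first_stage, has_last_stage = pipeline_llama(
#         model, world_mesh, parallel_dims, job_config, device, model_args,
#         parallelize_fn=train_spec.parallelize_fn, loss_fn=loss_fn,
#     )
#     # Only ranks holding the first stage feed inputs into pp_schedule.step(), and only ranks
#     # holding the last stage receive the loss.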