# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D pipeline parallelism to the Llama model.

import copy

import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import (
    _PipelineSchedule,
    get_schedule_class,
    ScheduleZBVZeroBubble,
)

from torchtitan.components.loss import LossFunction
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.pipeline import (
    build_pipeline_schedule,
    generate_split_points,
    stage_ids_this_rank,
)
from torchtitan.protocols.train_spec import DeviceType, ParallelizeFunction
from torchtitan.tools.logging import logger

from .model import TransformerModelArgs


def pipeline_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: TransformerModelArgs,
    parallelize_fn: ParallelizeFunction,
    loss_fn: LossFunction,
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
    pp_mesh = world_mesh["pp"]

    stages, model_parts = pipeline_llama_manual_split(
        model, pp_mesh, parallel_dims, job_config, device, model_config
    )

    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
    # optimizer, and checkpointing.
    for i, m in enumerate(model_parts):
        # apply SPMD-style PT-D techniques
        m = parallelize_fn(m, world_mesh, parallel_dims, job_config)
        model_parts[i] = m
        # NOTE: this is to update the model in the stage,
        # in case the model is modified e.g. by torch.compile
        stages[i].submod = m

    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    # This is used in the train loop to determine whether to pass in the input_ids and labels
    has_first_stage = False
    has_last_stage = False
    for stage in stages:
        if stage.is_first:
            has_first_stage = True
        if stage.is_last:
            has_last_stage = True

    return pp_schedule, model_parts, has_first_stage, has_last_stage


def pipeline_llama_manual_split(
    whole_model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: TransformerModelArgs,
) -> tuple[list[PipelineStage], list[nn.Module]]:
    """
    This API extracts one torch.nn.Module object for each part of the model
    configured to run inside a stage on this rank, wraps each model chunk in a
    PipelineStage object, and returns both the stage and model objects.

    The stage objects are used to create a pipeline schedule, and the model
    objects can be used for applying SPMD parallelism.
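
    For illustration only (the numbers are assumed, not read from any config):
    with n_layers = 8 and 4 pipeline stages, split points of
    ["layers.2", "layers.4", "layers.6"] yield stages owning layers 0-1, 2-3,
    4-5, and 6-7, because each split point is inclusive for the stage that
    starts there and exclusive for the stage before it. The first stage also
    keeps tok_embeddings, and the last stage keeps norm and output.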
""" pp_rank = pp_mesh.get_local_rank() pp_size = pp_mesh.size() parallelism_config = job_config.parallelism splits = parallelism_config.pipeline_parallel_split_points or generate_split_points( parallelism_config.pipeline_parallel_schedule, parallelism_config.pipeline_parallel_layers_per_stage, parallel_dims.pp, model_config.n_layers, ) def _build_stage( stage_idx: int, start_layer: str | None, stop_layer: str | None, is_first: bool = False, is_last: bool = False, ) -> tuple[PipelineStage, nn.Module]: model = copy.deepcopy(whole_model) if not is_first: model.tok_embeddings = None drop_layers = start_layer is not None for name in list(model.layers.keys()): # we keep layers in a contiguous region between start (inclusive) and stop (exclusive) if f"layers.{name}" == start_layer: drop_layers = False if f"layers.{name}" == stop_layer: drop_layers = True if drop_layers: del model.layers[name] if not is_last: model.norm = None model.output = None stage = PipelineStage( model, stage_idx, num_stages, device, group=pp_mesh.get_group("pp"), ) return stage, model num_stages = len(splits) + 1 stage_idx = pp_rank stages = [] models = [] schedule_class = get_schedule_class(parallelism_config.pipeline_parallel_schedule) style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop" for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style): start_layer = splits[stage_idx - 1] if stage_idx > 0 else None stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None stage, model_chunk = _build_stage( stage_idx, start_layer, stop_layer, is_first=stage_idx == 0, is_last=stage_idx == num_stages - 1, ) logger.info( f"PP rank {pp_rank} is building stage_idx {stage_idx}" f" with start_layer {start_layer}, stop_layer {stop_layer}" ) stages.append(stage) models.append(model_chunk) return stages, models