|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy |
|
|
|
import torch.nn as nn |
|
from torch.distributed import DeviceMesh |
|
from torch.distributed.pipelining import PipelineStage |
|
from torch.distributed.pipelining.schedules import ( |
|
_PipelineSchedule, |
|
get_schedule_class, |
|
ScheduleZBVZeroBubble, |
|
) |
|
|
|
from torchtitan.components.loss import LossFunction |
|
from torchtitan.config_manager import JobConfig |
|
from torchtitan.distributed import ParallelDims |
|
from torchtitan.distributed.pipeline import ( |
|
build_pipeline_schedule, |
|
generate_split_points, |
|
stage_ids_this_rank, |
|
) |
|
from torchtitan.protocols.train_spec import DeviceType, ParallelizeFunction |
|
from torchtitan.tools.logging import logger |
|
|
|
from .model import TransformerModelArgs |
|
|
|
|
|
def pipeline_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: TransformerModelArgs,
    parallelize_fn: ParallelizeFunction,
    loss_fn: LossFunction,
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
    """Split ``model`` into pipeline stages, parallelize each chunk, and build a schedule.

    Args:
        model: The whole model to split.
        world_mesh: Device mesh; its ``"pp"`` sub-mesh is used for the split.
        parallel_dims: Forwarded to the split helper and to ``parallelize_fn``.
        job_config: Forwarded to the split helper, ``parallelize_fn`` and the
            schedule builder.
        device: Forwarded to the split helper.
        model_config: Forwarded to the split helper.
        parallelize_fn: Called once per model chunk as
            ``parallelize_fn(chunk, world_mesh, parallel_dims, job_config)``;
            its return value replaces the chunk.
        loss_fn: Passed to ``build_pipeline_schedule``.

    Returns:
        ``(pp_schedule, model_parts, has_first_stage, has_last_stage)`` where the
        two booleans say whether any stage built here reports ``is_first`` /
        ``is_last``.
    """
    stages, model_parts = pipeline_llama_manual_split(
        model, world_mesh["pp"], parallel_dims, job_config, device, model_config
    )

    # Each entry of model_parts pairs with the stage at the same index.
    for idx, chunk in enumerate(model_parts):
        parallelized = parallelize_fn(chunk, world_mesh, parallel_dims, job_config)
        model_parts[idx] = parallelized
        # parallelize_fn may hand back a different module object, so re-point
        # the stage at whatever it returned.
        stages[idx].submod = parallelized

    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    has_first_stage = any(stage.is_first for stage in stages)
    has_last_stage = any(stage.is_last for stage in stages)

    return pp_schedule, model_parts, has_first_stage, has_last_stage
|
|
|
|
|
def pipeline_llama_manual_split(
    whole_model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: TransformerModelArgs,
) -> tuple[list[PipelineStage], list[nn.Module]]:
    """Build the pipeline stage(s) and model chunk(s) for this pipeline-parallel rank.

    For every stage index assigned to this rank, a deep copy of ``whole_model``
    is reduced to the modules that stage keeps and wrapped in a
    ``PipelineStage``. The stage objects are used to create a pipeline
    schedule, and the model chunks can be used for applying SPMD parallelism.

    Args:
        whole_model: The full model; it is deep-copied once per stage and not
            modified itself.
        pp_mesh: Pipeline-parallel device mesh, queried for the local rank, the
            size, and the ``"pp"`` process group.
        parallel_dims: Only ``parallel_dims.pp`` is read, when split points
            have to be generated.
        job_config: ``job_config.parallelism`` supplies the schedule name, the
            optional explicit split points, and the layers-per-stage setting.
        device: Device passed to each ``PipelineStage``.
        model_config: Only ``model_config.n_layers`` is read, when split points
            have to be generated.

    Returns:
        ``(stages, models)``: two lists of equal length, where ``models[i]`` is
        the module wrapped by ``stages[i]``.
    """
    pp_rank = pp_mesh.get_local_rank()
    pp_size = pp_mesh.size()
    parallelism_config = job_config.parallelism

    # Explicit split points win; otherwise derive them from the schedule/layer config.
    splits = parallelism_config.pipeline_parallel_split_points or generate_split_points(
        parallelism_config.pipeline_parallel_schedule,
        parallelism_config.pipeline_parallel_layers_per_stage,
        parallel_dims.pp,
        model_config.n_layers,
    )
    # Bound before `_build_stage` is defined so the closure does not depend on
    # a name that only appears later in the function body.
    num_stages = len(splits) + 1

    def _build_stage(
        stage_idx: int,
        start_layer: str | None,
        stop_layer: str | None,
        is_first: bool = False,
        is_last: bool = False,
    ) -> tuple[PipelineStage, nn.Module]:
        """Return ``(stage, model)`` for one stage, built from a pruned copy of the model."""
        model = copy.deepcopy(whole_model)
        if not is_first:
            model.tok_embeddings = None

        # Keep the contiguous run of layers from start_layer (inclusive) up to
        # stop_layer (exclusive). With no start_layer, keeping begins at the
        # first layer.
        drop_layers = start_layer is not None
        for name in list(model.layers.keys()):
            if f"layers.{name}" == start_layer:
                drop_layers = False
            if f"layers.{name}" == stop_layer:
                drop_layers = True
            if drop_layers:
                del model.layers[name]

        if not is_last:
            model.norm = None
            model.output = None

        stage = PipelineStage(
            model,
            stage_idx,
            num_stages,
            device,
            group=pp_mesh.get_group("pp"),
        )
        return stage, model

    stages = []
    models = []

    # Only the ZBV zero-bubble schedule uses the "v" stage layout; all others use "loop".
    schedule_class = get_schedule_class(parallelism_config.pipeline_parallel_schedule)
    style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"

    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
        # splits[k] names the first layer of stage k + 1, so stage 0 has no
        # start and the final stage has no stop.
        start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
        stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
        stage, model_chunk = _build_stage(
            stage_idx,
            start_layer,
            stop_layer,
            is_first=stage_idx == 0,
            is_last=stage_idx == num_stages - 1,
        )
        logger.info(
            f"PP rank {pp_rank} is building stage_idx {stage_idx}"
            f" with start_layer {start_layer}, stop_layer {stop_layer}"
        )
        stages.append(stage)
        models.append(model_chunk)
    return stages, models
|
|