import copy
from typing import Callable, Optional, Union

import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import (
    ScheduleZBVZeroBubble,
    _PipelineSchedule,
    get_schedule_class,
)
from transformers import PretrainedConfig

from flame.models.parallelize_fla import get_blocks, get_components_name, get_model
from torchtitan.config_manager import JobConfig
from torchtitan.distributed.parallel_dims import ParallelDims
from torchtitan.distributed.pipeline import (
    build_pipeline_schedule,
    generate_split_points,
    stage_ids_this_rank,
)
from torchtitan.tools.logging import logger

DeviceType = Union[int, str, torch.device]


def pipeline_fla(
    model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: PretrainedConfig,
    loss_fn: Callable[..., torch.Tensor],
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
    """Split the model into the pipeline stages owned by this rank and build the pipeline schedule."""
    stages, models = pipeline_fla_manual_split(
        model, pp_mesh, parallel_dims, job_config, device, model_config
    )

    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    # Record whether this rank owns the first and/or last stage; the train loop
    # uses these flags to decide whether to feed inputs and to compute the loss.
    has_first_stage = False
    has_last_stage = False
    for stage in stages:
        if stage.is_first:
            has_first_stage = True
        if stage.is_last:
            has_last_stage = True

    return pp_schedule, models, has_first_stage, has_last_stage
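
# Illustrative usage (a rough sketch, not executed here): a training script would
# typically obtain the pipeline sub-mesh from its device mesh and pass the same
# objects it already uses for parallelism setup. The surrounding names
# (`world_mesh`, `loss_fn`, `model_parts`) are assumptions for the example, not
# part of this module:
#
#     pp_mesh = world_mesh["pp"]
#     pp_schedule, model_parts, has_first, has_last = pipeline_fla(
#         model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
#     )
#     # Only ranks with has_first feed input_ids; only ranks with has_last see the loss.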


def pipeline_fla_manual_split(
    whole_model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: PretrainedConfig,
) -> tuple[list[PipelineStage], list[nn.Module]]:
    """
    This API extracts one torch.nn.Module object for each part of the model configured to run inside
    a pipeline stage owned by this rank.

    It wraps each model chunk in a PipelineStage object and returns both the stage and model objects.

    The stage objects are used to create a pipeline schedule, and the model objects can be used for
    applying SPMD parallelism.
    """
    pp_rank = pp_mesh.get_local_rank()
    pp_size = pp_mesh.size()

    splits = (
        job_config.experimental.pipeline_parallel_split_points
        or generate_split_points(
            job_config, parallel_dims.pp, model_config.num_hidden_layers
        )
    )
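
    # Split points are qualified block names such as "layers.4". For example, with
    # 8 hidden layers and pipeline degree 2, generate_split_points would typically
    # return ["layers.4"], giving two stages (layers 0-3 and layers 4-7).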
    def _build_stage(
        stage_idx: int,
        start_layer: Optional[str],
        stop_layer: Optional[str],
        is_first: bool = False,
        is_last: bool = False,
    ) -> tuple[PipelineStage, nn.Module]:
        model = copy.deepcopy(whole_model)
        if not is_first:
            # Non-first stages receive hidden states from the previous stage,
            # so the token embeddings are not needed.
            real_model = get_model(model)
            tok_embeddings_name = get_components_name(real_model, "tok_embeddings")
            setattr(real_model, tok_embeddings_name, None)

        # Keep only the blocks in [start_layer, stop_layer); blocks before
        # start_layer and from stop_layer onwards belong to other stages.
        drop_layers = start_layer is not None

        module_dict = get_blocks(model)._modules
        layer_names = list(module_dict.keys())

        for name in layer_names:
            # Block names in the module dict are bare indices ("0", "1", ...),
            # while split points are qualified names such as "layers.0".
            prefix = start_layer.split(".")[0] if start_layer else "layers"
            layer_name = f"{prefix}.{name}"

            if layer_name == start_layer:
                drop_layers = False
            if layer_name == stop_layer:
                drop_layers = True

            if drop_layers:
                del module_dict[name]

        if not is_last:
            # Non-last stages do not produce logits, so drop the final norm
            # and the language-model head.
            real_model = get_model(model)
            norm_name = get_components_name(real_model, "norm")
            setattr(real_model, norm_name, None)

            head_name = get_components_name(model, "lm_head")
            setattr(model, head_name, None)

        stage = PipelineStage(
            model,
            stage_idx,
            num_stages,
            device,
            group=pp_mesh.get_group("pp"),
        )
        return stage, model
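
    # As a concrete illustration (using the two-stage example above): stage 0 keeps
    # the token embeddings and layers 0-3 but has its norm and lm_head set to None,
    # while stage 1 keeps layers 4-7, the norm, and the lm_head but has its token
    # embeddings set to None.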
    num_stages = len(splits) + 1
    stage_idx = pp_rank

    stages = []
    models = []

    schedule_class = get_schedule_class(
        job_config.experimental.pipeline_parallel_schedule
    )
    style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"
    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
        start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
        stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
        stage, model_chunk = _build_stage(
            stage_idx,
            start_layer,
            stop_layer,
            is_first=stage_idx == 0,
            is_last=stage_idx == num_stages - 1,
        )
        logger.info(
            f"PP rank {pp_rank} is building stage_idx {stage_idx}"
            f" with start_layer {start_layer}, stop_layer {stop_layer}"
        )
        stages.append(stage)
        models.append(model_chunk)
    return stages, models