|
|
|
|
|
|
|
|
|
|
|
import math |
|
import os |
|
from typing import Callable, Optional |
|
|
|
from torch.distributed.pipelining.schedules import ( |
|
_PipelineSchedule, |
|
_PipelineScheduleRuntime, |
|
get_schedule_class, |
|
PipelineScheduleMulti, |
|
PipelineScheduleSingle, |
|
) |
|
from torch.distributed.pipelining.stage import PipelineStage |
|
|
|
from torchtitan.config_manager import JobConfig |
|
from torchtitan.tools.logging import logger |
|
|
|
|
|
__all__ = ["build_pipeline_schedule", "generate_split_points", "stage_ids_this_rank"] |
|
|
|
|
|
|
|
|
|
def generate_split_points(
    schedule_str: str,
    layers_per_stage: Optional[int],
    pp_dim: int,
    num_layers: int,
    input_weight: int = 1,
    output_weight: int = 1,
) -> list[str]:
    """
    Generate a list of split points based on the number of layers and
    pipeline parallel dimension, ensuring the first and last stages have the least layers.

    Args:
        schedule_str (str): The string of the schedule name.
        layers_per_stage (Optional[int]): The number of layers per stage. If None,
            the stage count is derived from the schedule type (1 stage per rank for
            single-stage schedules, 2 for multi-stage schedules).
        pp_dim (int): The pipeline parallel dimension.
        num_layers (int): The number of layers in the model.
        input_weight (int): The number of layers-worth of cost attributed to the
            input modules (embedding, etc.) on the first stage.
        output_weight (int): The number of layers-worth of cost attributed to the
            output modules (norm/head, etc.) on the last stage.

    Returns:
        list[str]: A list of split point FQNs (e.g. "layers.4"). Empty when there
            is only a single stage.

    Raises:
        ValueError: If the derived stage count is incompatible with pp_dim, the
            schedule type, or the number of layers.
    """

    schedule_class = get_schedule_class(schedule_str)
    is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)
    # Multi-stage (looped) schedules need at least 2 stages per rank.
    num_stages_per_rank = 1 if is_single_stage_schedule else 2

    if layers_per_stage is not None:
        total_stages = math.ceil(num_layers / layers_per_stage)
        if total_stages % pp_dim != 0:
            raise ValueError(
                f"Number of stages ({total_stages}) must be divisible by the pipeline parallel dimension ({pp_dim})."
                f"Each rank should have the same number of stages. "
            )
        num_stages_per_rank = total_stages // pp_dim

        if is_single_stage_schedule and num_stages_per_rank != 1:
            raise ValueError(
                f"Number of stages per rank ({num_stages_per_rank}) must be 1 for single stage schedules."
            )
        elif not is_single_stage_schedule and num_stages_per_rank < 2:
            raise ValueError(
                f"Number of stages per rank ({num_stages_per_rank}) must be >= 2 for multi stage schedules."
            )
    else:
        total_stages = pp_dim * num_stages_per_rank
        if total_stages > num_layers:
            raise ValueError("Total stages cannot be greater than the number of layers")

    # A single stage needs no split points. (Previously this fell through and
    # raised IndexError on the empty splits list.)
    if total_stages == 1:
        return []

    # Treat the input/output modules as extra "virtual" layers so the first and
    # last stages are assigned fewer transformer layers to compensate.
    effective_num_layers = num_layers + input_weight + output_weight
    base_layers_per_stage = effective_num_layers // total_stages

    splits = [""] * (total_stages - 1)

    # First stage: discount the input modules' weight.
    layers_on_first_stage = max(0, base_layers_per_stage - input_weight)
    current_layer_index = layers_on_first_stage
    splits[0] = "layers." + str(current_layer_index)

    # Last stage: discount the output modules' weight. When total_stages == 2
    # this overwrites splits[0], which is the intended single split point.
    layers_on_last_stage = max(0, base_layers_per_stage - output_weight)
    splits[-1] = "layers." + str(num_layers - layers_on_last_stage)

    # Distribute the remaining layers as evenly as possible over the middle
    # stages, spreading any remainder one layer at a time from the front.
    middle_stages = len(splits) - 2
    if middle_stages > 0:
        remaining_layers = (
            num_layers - layers_on_first_stage - layers_on_last_stage - 1
        )
        layers_per_middle_stage = remaining_layers // middle_stages
        remainder = remaining_layers % middle_stages

        for i in range(1, middle_stages + 1):
            current_layer_index += layers_per_middle_stage
            if remainder > 0:
                current_layer_index += 1
                remainder -= 1
            splits[i] = "layers." + str(current_layer_index)

    logger.info(
        f"No 'pipeline_parallel_split_points' provided so the generated splits are: {splits} "
        "This may be sub-optimal as the number of layers per stage may be unbalanced."
    )
    return splits
|
|
|
|
|
def build_pipeline_schedule(
    job_config: JobConfig, stages: list[PipelineStage], loss_fn: Callable
) -> _PipelineSchedule:
    """Builds a pipeline schedule for the given job configuration and stages.

    Args:
        job_config (JobConfig): The job configuration.
        stages (list[PipelineStage]): The stages to be scheduled.
        loss_fn (Callable): The loss function.

    Returns:
        _PipelineSchedule: The pipeline schedule for the given stages.

    Raises:
        FileNotFoundError: If a schedule CSV path is configured but missing.
        ValueError: If the batch size is not divisible by the microbatch size.
    """
    pp_schedule_csv = job_config.parallelism.pipeline_parallel_schedule_csv

    # A CSV schedule overrides the named schedule: load it via the generic
    # runtime schedule class.
    if pp_schedule_csv:
        if not os.path.isfile(pp_schedule_csv):
            raise FileNotFoundError(
                f"The specified path {pp_schedule_csv} does not exist or is not a file."
            )
        schedule_class = _PipelineScheduleRuntime
    else:
        schedule_class = get_schedule_class(
            job_config.parallelism.pipeline_parallel_schedule
        )

    # Multi-stage ("looped") schedules take the full list of local stages;
    # single-stage schedules take exactly one.
    looped_schedule = issubclass(schedule_class, PipelineScheduleMulti)
    microbatch_size = job_config.parallelism.pipeline_parallel_microbatch_size
    batch_size = job_config.training.batch_size

    # validate that the batch size is divisible by the microbatch_size otherwise we'll hang or error during training
    if batch_size % microbatch_size != 0:
        # Bug fix: the previous message referenced n_microbatches before it was
        # defined, raising NameError instead of this ValueError.
        raise ValueError(
            f"Batch size {batch_size} must be divisible by the microbatch size {microbatch_size}. "
            "Update the config arguments for either batch_size or pipeline_parallel_microbatch_size."
        )
    n_microbatches = batch_size // microbatch_size

    # Fewer microbatches than stages leaves some stages idle (pipeline bubble).
    num_total_stages = job_config.parallelism.pipeline_parallel_degree * len(stages)
    if n_microbatches < num_total_stages:
        logger.warning(
            f"Number of microbatches ({n_microbatches}) is less than the total number "
            f"of stages ({num_total_stages}) which may result in a bubble in the pipeline."
        )

    schedule = schedule_class(
        stages if looped_schedule else stages[0],
        n_microbatches=n_microbatches,
        loss_fn=loss_fn,
    )
    logger.info(
        f"Using pipeline schedule {job_config.parallelism.pipeline_parallel_schedule} "
        f"with {n_microbatches} microbatches and {num_total_stages} stages."
    )

    if pp_schedule_csv:
        assert schedule_class in [
            PipelineScheduleSingle,
            PipelineScheduleMulti,
            _PipelineScheduleRuntime,
        ], (
            "Only PipelineScheduleSingle (single stage), PipelineScheduleMulti (multistage), "
            "and _PipelineScheduleRuntime support csv schedules"
        )
        schedule._load_csv(pp_schedule_csv)

    return schedule
|
|
|
|
|
|
|
def stage_ids_this_rank(
    pp_rank: int, pp_size: int, num_stages: int, style: str = "loop"
) -> tuple[int, ...]:
    """Compute the stage ids for the stages that will run on this pp rank for either a looped or V style schedule.

    Args:
        pp_rank (int): This rank's index in the pipeline-parallel dimension.
        pp_size (int): Number of pipeline-parallel ranks.
        num_stages (int): Total number of pipeline stages; must be divisible by pp_size.
        style (str): "loop" for round-robin assignment, "v" for V-schedule pairing.

    Returns:
        tuple[int, ...]: The stage ids assigned to this rank.

    Raises:
        ValueError: If style is not "loop" or "v".
    """
    assert (
        num_stages % pp_size == 0
    ), f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size}"
    stages_per_rank = num_stages // pp_size
    if style == "loop":
        # Round-robin: rank r gets stages r, r+pp_size, r+2*pp_size, ...
        return tuple(pp_rank + s * pp_size for s in range(stages_per_rank))
    elif style == "v":
        assert (
            stages_per_rank == 2
        ), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
        # Pair the i-th stage from the front with the i-th stage from the back:
        # rank 0 -> (0, num_stages-1), rank 1 -> (1, num_stages-2), ...
        stage_v_pairs = list(
            zip(range(pp_size), range(num_stages - 1, pp_size - 1, -1))
        )
        return stage_v_pairs[pp_rank]
    # Bug fix: previously an unknown style silently returned None.
    raise ValueError(f"Unsupported style: {style}")
|
|