# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from collections.abc import Callable
from dataclasses import dataclass
from functools import cached_property

from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

from torchtitan.tools.logging import logger

__all__ = ["ParallelDims"]

@dataclass
class ParallelDims:
    dp_replicate: int
    dp_shard: int
    cp: int
    tp: int
    pp: int
    world_size: int
    enable_loss_parallel: bool

    def __post_init__(self):
        self._validate()

    def _validate(self):
        dp_replicate, dp_shard, cp, tp, pp = (
            self.dp_replicate,
            self.dp_shard,
            self.cp,
            self.tp,
            self.pp,
        )
        for d in (dp_replicate, cp, tp, pp):
            assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
        assert dp_shard == -1 or dp_shard >= 1, "dp_shard must be -1 or >= 1."
        if dp_shard < 0:
            self.dp_shard = dp_shard = self.world_size // (dp_replicate * cp * tp * pp)
        assert dp_shard >= 1

        assert dp_replicate * dp_shard * cp * tp * pp == self.world_size, (
            f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
            f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
        )
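
    # Illustrative example (values assumed, not taken from any config): with
    # world_size=8, dp_replicate=2, cp=1, tp=2, pp=1, and dp_shard=-1,
    # dp_shard is inferred as 8 // (2 * 1 * 2 * 1) = 2, and the product
    # check above holds since 2 * 2 * 1 * 2 * 1 == 8.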

    def build_mesh(self, device_type: str) -> DeviceMesh:
        dims = []
        names = []
        for d, name in zip(
            [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
            ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
        ):
            if d > 1:
                dims.append(d)
                names.append(name)
        return self._build_mesh(device_type, dims, names, init_device_mesh)
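
    # Illustrative example (values assumed): with pp=1, dp_replicate=2,
    # dp_shard=2, cp=1, tp=2, only degrees > 1 survive the filter above, so
    # build_mesh produces a 3-D mesh with dims (2, 2, 2) and mesh_dim_names
    # ("dp_replicate", "dp_shard", "tp").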

    def _build_mesh(
        self,
        device_type: str,
        dims: list[int],
        names: list[str],
        init_device_mesh_fn: Callable,
    ) -> DeviceMesh:
        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
        mesh = init_device_mesh_fn(device_type, dims, mesh_dim_names=names)

        # Create all the submeshes here to ensure all required process groups
        # are initialized:
        # Mesh for data loading (no communication on this mesh)
        dp_mesh_dim_names = []
        # Mesh for param sharding
        dp_shard_cp_mesh_dim_names = []
        # Mesh for loss all-reduce
        dp_cp_mesh_dim_names = []

        if self.dp_replicate_enabled:
            dp_mesh_dim_names.append("dp_replicate")
            dp_cp_mesh_dim_names.append("dp_replicate")
        if self.dp_shard_enabled:
            dp_mesh_dim_names.append("dp_shard")
            dp_shard_cp_mesh_dim_names.append("dp_shard")
            dp_cp_mesh_dim_names.append("dp_shard")
        if self.cp_enabled:
            dp_shard_cp_mesh_dim_names.append("cp")
            dp_cp_mesh_dim_names.append("cp")

        if dp_mesh_dim_names != []:
            mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
        if dp_shard_cp_mesh_dim_names != []:
            mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(
                mesh_dim_name="dp_shard_cp"
            )
        if dp_cp_mesh_dim_names != []:
            mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")

        return mesh
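
    # Illustrative example (dims assumed): if dp_replicate, dp_shard, and cp
    # are all enabled, the flattened submeshes are "dp" over
    # ("dp_replicate", "dp_shard") for data loading, "dp_shard_cp" over
    # ("dp_shard", "cp") for param sharding, and "dp_cp" over
    # ("dp_replicate", "dp_shard", "cp") for the loss all-reduce.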

    @property
    def dp_enabled(self):
        return self.dp_replicate > 1 or self.dp_shard > 1

    @property
    def dp_replicate_enabled(self):
        return self.dp_replicate > 1

    @property
    def dp_shard_enabled(self):
        return self.dp_shard > 1

    @property
    def cp_enabled(self):
        return self.cp > 1

    @property
    def tp_enabled(self):
        return self.tp > 1

    @property
    def pp_enabled(self):
        return self.pp > 1

    @property
    def loss_parallel_enabled(self):
        return self.tp > 1 and self.enable_loss_parallel

    @cached_property
    def non_data_parallel_size(self):
        return self.cp * self.tp * self.pp
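
# A minimal usage sketch (illustrative, not part of this module; assumes a
# distributed run, e.g. `torchrun --nproc_per_node=8`, where WORLD_SIZE matches
# the product of the parallelism degrees):
#
#     parallel_dims = ParallelDims(
#         dp_replicate=2,
#         dp_shard=-1,  # inferred as 2 on 8 ranks with tp=2
#         cp=1,
#         tp=2,
#         pp=1,
#         world_size=8,
#         enable_loss_parallel=True,
#     )
#     mesh = parallel_dims.build_mesh(device_type="cuda")
#     dp_mesh = mesh["dp"]  # flattened dp_replicate x dp_shard submesh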