# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections.abc import Callable
from dataclasses import dataclass
from functools import cached_property

from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

from torchtitan.tools.logging import logger

__all__ = ["ParallelDims"]


@dataclass
class ParallelDims:
    """Parallelism degrees for a job; their product must equal world_size.

    dp_shard may be -1, in which case it is inferred in _validate().
    """

    dp_replicate: int
    dp_shard: int
    cp: int
    tp: int
    pp: int
    world_size: int
    enable_loss_parallel: bool

    def __post_init__(self):
        self._validate()
    def _validate(self):
        dp_replicate, dp_shard, cp, tp, pp = (
            self.dp_replicate,
            self.dp_shard,
            self.cp,
            self.tp,
            self.pp,
        )
        for d in (dp_replicate, cp, tp, pp):
            assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"

        assert dp_shard == -1 or dp_shard >= 1, "dp_shard must be -1 or >= 1."
        if dp_shard < 0:
            # Infer dp_shard from the world size and the other parallelism degrees.
            self.dp_shard = dp_shard = self.world_size // (dp_replicate * cp * tp * pp)
        assert dp_shard >= 1

        assert dp_replicate * dp_shard * cp * tp * pp == self.world_size, (
            f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
            f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
        )
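
    # Worked example (illustrative values, not from this file): with
    # world_size=8, dp_replicate=2, cp=1, tp=2, pp=1, and dp_shard=-1,
    # _validate() infers dp_shard = 8 // (2 * 1 * 2 * 1) = 2, and the final
    # check 2 * 2 * 1 * 2 * 1 == 8 passes.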
    def build_mesh(self, device_type: str) -> DeviceMesh:
        dims = []
        names = []
        for d, name in zip(
            [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
            ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
        ):
            if d > 1:
                dims.append(d)
                names.append(name)

        return self._build_mesh(device_type, dims, names, init_device_mesh)
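
    # Continuing the example above, build_mesh() drops the degree-1 dims
    # (pp and cp) and builds a 3-D mesh with dims [2, 2, 2] and names
    # ["dp_replicate", "dp_shard", "tp"].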
    def _build_mesh(
        self,
        device_type: str,
        dims: list[int],
        names: list[str],
        init_device_mesh_fn: Callable,
    ) -> DeviceMesh:
        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
        mesh = init_device_mesh_fn(device_type, dims, mesh_dim_names=names)

        # Create all the submeshes here to ensure all required process groups
        # are initialized:
        # Mesh for data loading (no communication on this mesh)
        dp_mesh_dim_names = []
        # Mesh for param sharding
        dp_shard_cp_mesh_dim_names = []
        # Mesh for loss all-reduce
        dp_cp_mesh_dim_names = []

        if self.dp_replicate_enabled:
            dp_mesh_dim_names.append("dp_replicate")
            dp_cp_mesh_dim_names.append("dp_replicate")
        if self.dp_shard_enabled:
            dp_mesh_dim_names.append("dp_shard")
            dp_shard_cp_mesh_dim_names.append("dp_shard")
            dp_cp_mesh_dim_names.append("dp_shard")
        if self.cp_enabled:
            dp_shard_cp_mesh_dim_names.append("cp")
            dp_cp_mesh_dim_names.append("cp")

        if dp_mesh_dim_names:
            mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
        if dp_shard_cp_mesh_dim_names:
            mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(
                mesh_dim_name="dp_shard_cp"
            )
        if dp_cp_mesh_dim_names:
            mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")

        return mesh
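
    # For the same example, the flattened submeshes come out as:
    #   "dp"          = (dp_replicate, dp_shard) -> 4 ranks, data loading
    #   "dp_shard_cp" = (dp_shard,)              -> 2 ranks, param sharding
    #   "dp_cp"       = (dp_replicate, dp_shard) -> 4 ranks, loss all-reduce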
    @property
    def dp_enabled(self):
        return self.dp_replicate > 1 or self.dp_shard > 1

    @property
    def dp_replicate_enabled(self):
        return self.dp_replicate > 1

    @property
    def dp_shard_enabled(self):
        return self.dp_shard > 1

    @property
    def cp_enabled(self):
        return self.cp > 1

    @property
    def tp_enabled(self):
        return self.tp > 1

    @property
    def pp_enabled(self):
        return self.pp > 1

    @property
    def loss_parallel_enabled(self):
        return self.tp > 1 and self.enable_loss_parallel

    @cached_property
    def non_data_parallel_size(self):
        return self.cp * self.tp * self.pp
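

# A minimal usage sketch (illustrative, not part of the upstream module),
# assuming an 8-rank job launched via torchrun:
#
#     parallel_dims = ParallelDims(
#         dp_replicate=2,
#         dp_shard=-1,  # inferred as 8 // (2 * 1 * 2 * 1) = 2
#         cp=1,
#         tp=2,
#         pp=1,
#         world_size=8,
#         enable_loss_parallel=True,
#     )
#     world_mesh = parallel_dims.build_mesh(device_type="cuda")
#     dp_mesh = world_mesh["dp"]  # flattened data-parallel submesh, 4 ranks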