# coding=utf-8
# Copyright 2023 The T5X Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for partitioning."""

import abc
import collections
import dataclasses
import typing
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import cached_property
import jax
import numpy as np
from absl import logging
from flax import traverse_util
from flax.linen import partitioning as flax_partitioning
from jax import numpy as jnp
from jax import random
from jax.experimental import multihost_utils
from jax.experimental.mesh_utils import create_hybrid_device_mesh
from jax.experimental.pjit import pjit as jax_pjit
from jax.sharding import Mesh, PartitionSpec


JaxDevice = Any
TpuMesh = Tuple[int, int, int, int]  # (x, y, z, num_cores).
OtherMesh = Tuple[int, int]
HardwareMesh = Union[TpuMesh, OtherMesh]
PyTreeDef = type(jax.tree_util.tree_structure(None))
TrainState = Any
LogicalAxisRules = Sequence[Tuple[str, Optional[str]]]

if typing.TYPE_CHECKING:  # See b/163639353
    cached_property = property  # pylint: disable=invalid-name
else:
    cached_property = cached_property.cached_property


class AxisNames(tuple):
    """Tuple of strings specifying the name of each axis.

    We create a separate class for this so JAX's pytree utilities can distinguish
    it from a tuple that should be treated as a pytree, instead treating it as a
    leaf.
    """

    def __new__(cls, *names):
        return tuple.__new__(AxisNames, names)

    def __repr__(self):
        return "AxisNames%s" % tuple.__repr__(self)


# pjit wrappers for cpu fallback.
# -----------------------------------------------------------------------------
# TODO(levskaya): This function is now no different from jax_pjit, but callers
# currently depend on the `backend` argument.
def pjit(
    fun: Callable,  # pylint: disable=g-bare-generic
    in_axis_resources,
    out_axis_resources,
    static_argnums: Union[int, Sequence[int]] = (),
    donate_argnums: Union[int, Sequence[int]] = (),
    backend: Optional[str] = None,
):
    """Wrapper for pjit."""
    del backend
    return jax_pjit(
        fun, in_axis_resources, out_axis_resources, static_argnums=static_argnums, donate_argnums=donate_argnums
    )


# pjit wrappers for cpu fallback.
# -----------------------------------------------------------------------------
# TODO(levskaya): upstream this fallback behavior to jax pjit.
def pjit_with_cpu_fallback(
    fun: Callable,  # pylint: disable=g-bare-generic
    in_axis_resources,
    out_axis_resources,
    static_argnums: Union[int, Sequence[int]] = (),
    donate_argnums: Union[int, Sequence[int]] = (),
    backend: Optional[str] = None,
):
    """Wrapper for pjit that calls normal jit on cpu."""
    if jax.devices(backend)[0].platform == "cpu":
        return jax.jit(fun, static_argnums=static_argnums, donate_argnums=donate_argnums)
    else:
        return jax_pjit(
            fun, in_axis_resources, out_axis_resources, static_argnums=static_argnums, donate_argnums=donate_argnums
        )


def with_sharding_constraint(x, axis_resources):
    """Wrapper for pjit's with_sharding_constraint; a no-op on cpu or outside pjit."""
    if jax.devices()[0].platform == "cpu" or not global_mesh_defined():
        return x
    else:
        return jax.experimental.pjit.with_sharding_constraint(x, axis_resources)
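

# Example (hypothetical usage sketch): inside a pjit-ed function running under a
# global mesh, constrain an intermediate activation to a ('data', 'model') layout;
# on CPU or outside a mesh context the call above simply returns `x` unchanged.
#
#   x = with_sharding_constraint(x, PartitionSpec("data", "model"))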


# pjit Mesh creation functions.
# -----------------------------------------------------------------------------
def bounds_from_last_device(last_device: JaxDevice) -> HardwareMesh:
    """Gets the bounds from the given last device."""
    # Must be passed the device at the highest-coordinate corner of the
    # relevant mesh, which is a requirement we know is satisfied by the last
    # device in jax.devices().
    if hasattr(last_device, "coords"):
        x, y, z = last_device.coords
        return x + 1, y + 1, z + 1, last_device.core_on_chip + 1
    else:
        # On non-TPU platforms, the "mesh" is hosts x devices per host in order
        # to take advantage of faster within-host interconnect.
        return jax.host_count(), jax.local_device_count()


def get_coords(device: JaxDevice) -> HardwareMesh:
    """Returns the coordinates of the given device."""
    if hasattr(device, "coords"):
        return (*device.coords, device.core_on_chip)
    return (device.process_index, device.id % jax.local_device_count())


def global_mesh_defined():
    """Checks if a global xmap/pjit mesh resource environment is defined."""
    maps_env = jax.experimental.maps.thread_resources.env
    return maps_env.physical_mesh.devices.shape != ()  # pylint: disable=g-explicit-bool-comparison


def get_mesh(
    model_parallel_submesh: HardwareMesh,
    input_devices: Sequence[JaxDevice] = (),
    input_local_devices: Sequence[JaxDevice] = (),
    tile_by_host_if_needed: bool = True,
    backend: Optional[str] = None,
) -> Mesh:
    """Constructs an xmap/pjit Mesh for the given model-parallel submesh.

    The resulting mesh has two resource axes: 'model', with the provided submesh
    shape, and 'data', which covers the rest of the mesh.

    Args:
        model_parallel_submesh: a HardwareMesh spec, namely (x, y, z, core) on TPU
            for a single model-parallel replica's "tile" in the physical device
            mesh. The first three elements (`x`, `y`, and `z`) should be factors of
            the pod slice; e.g., if you are using df_4x8, then `x` should be a
            factor of 4 (one of 1, 2, 4), `y` should be a factor of 8 (one of 1, 2,
            4, 8), and `z` must be 1, because TPU v3 slices are only 2D. `z` can be
            >1 for TPU v4 (and maybe later TPUs) that allow 3D slices. `core` is the
            number of cores to use from each TPU node. As communication is usually
            fastest inside the same node, if you need a tile of more than 1 core,
            you should first increase `core`: e.g., for TPU v3, (1, 1, 1, 2) is
            better than (2, 1, 1, 1). To pick a good spec, try a few possible values
            until you get high TPU utilization.
        input_devices: the devices to use; jax.devices() is used if this is not
            set.
        input_local_devices: the local devices to use; jax.local_devices() is used
            if this is not set.
        tile_by_host_if_needed: JAX currently requires that the parts of any
            sharded array that are located on one host's local devices form a
            single contiguous slice. A best effort will be made to achieve this
            without "tiling" the device assignment over hosts (which can reduce XLA
            collective performance). If this flag is True, the device assignment
            will be tiled over hosts if necessary to satisfy this constraint and
            create a buildable mesh; if False, mesh construction will fail instead.
        backend: get devices from the pinned backend, if specified. This is useful
            for explicitly specifying the devices rather than relying on
            jax_platform_name.

    Returns:
        An xmap/pjit Mesh containing the virtual device mesh with data, model axes.
    """
    input_devices = input_devices or jax.devices(backend)
    input_local_devices = input_local_devices or jax.local_devices(0, backend)
    # Sort input_devices based on coords, as backends might not return devices
    # in order.
    last_device = sorted(input_devices, key=get_coords)[-1]
    last_input_local_devices = sorted(input_local_devices, key=get_coords)[-1]
    logging.info(
        "last device coords : %r\nlast local device coords: %r",
        get_coords(last_device),
        get_coords(last_input_local_devices),
    )
    global_hardware_mesh = bounds_from_last_device(last_device)
    mesh_ndim = len(global_hardware_mesh)
    local_hardware_mesh = bounds_from_last_device(last_input_local_devices)
    mesh_err = (
        f"each dimension of the model parallel submesh {model_parallel_submesh} "
        "must be a factor of the corresponding dimension of the global device "
        f"mesh {global_hardware_mesh}"
    )
    assert not any(g % m for g, m in zip(global_hardware_mesh, model_parallel_submesh)), mesh_err
    assert not any(g % l for g, l in zip(global_hardware_mesh, local_hardware_mesh))
    devices = np.empty(global_hardware_mesh, dtype=object)
    for device in input_devices:
        device_coords = get_coords(device)
        devices[device_coords] = device
    tile_by_host = tile_by_host_if_needed
    if len(global_hardware_mesh) == 4:
        # enable contiguous local chunks without host tiling by making Z major
        global_hardware_mesh = typing.cast(Tuple[int, int, int, int], global_hardware_mesh)
        model_parallel_submesh = typing.cast(Tuple[int, int, int, int], model_parallel_submesh)
        gx, gy, gz, gc = global_hardware_mesh
        mx, my, mz, mc = model_parallel_submesh
        if (mx == gx > 1 and my == mz == 1) or (mx == 1 and my == gy > 1 and mz == gz > 1):
            logging.info("ensuring YZ plane has a Z-major device order")
            # YZ should be ZY
            assert mc == gc, (mc, gc)
            global_hardware_mesh = gx, gz, gy, gc
            model_parallel_submesh = mx, mz, my, mc
            devices = devices.swapaxes(1, 2)
            tile_by_host = False
        if (my == gy > 1 and mx == mz == 1) or (my == 1 and mx == gx > 1 and mz == gz > 1):
            logging.info("ensuring XZ plane has a Z-major device order")
            # XZ should be ZX
            assert mc == gc, (mc, gc)
            global_hardware_mesh = gz, gy, gx, gc
            model_parallel_submesh = mz, my, mx, mc
            devices = devices.swapaxes(0, 2)
            tile_by_host = False
    if tile_by_host:
        logging.warning(
            "Tiling device assignment mesh by hosts, which may lead to "
            "reduced XLA collective performance. To avoid this, modify "
            "the model parallel submesh or run with more tasks per host."
        )
        tile_err = (
            "to tile the mesh by hosts, each dimension of the model parallel "
            "submesh must be either a factor or a multiple of the corresponding "
            "dimension of the per-host submesh"
        )

        def dh_dd_mh_md(g: int, m: int, l: int) -> Tuple[int, int, int, int]:
            """Splits a global mesh dimension into four tiling components.

            Args:
                g: global mesh bounds dimension size
                m: model-parallel submesh bounds dimension size
                l: local submesh bounds dimension size

            Returns:
                The resulting tuple divides the dimension into the hosts component
                of the data-parallel submesh, the devices component of the
                data-parallel submesh, the hosts component of the model-parallel
                submesh, and the devices component of the model-parallel submesh.
            """
            d = g // m
            if m >= l:
                assert not m % l, tile_err
                return (d, 1, m // l, l)
            else:
                assert not l % m, tile_err
                return (d // (l // m), l // m, 1, m)

        # e.g. [(x_data_hosts, x_data_devs, x_model_hosts, x_model_devs), ...]
        dh_dd_mh_md_tups = map(dh_dd_mh_md, global_hardware_mesh, model_parallel_submesh, local_hardware_mesh)
        # reshape to e.g. (x_dh, x_dd, x_mh, x_md, y_dh, ...)
        devices = devices.reshape(*(s for t in dh_dd_mh_md_tups for s in t))  # pylint: disable=g-complex-comprehension
        # TODO(jekbradbury): reorder local subgroups for ring locality
        # Transpose to [data_host], [data_device], [model_host], [model_device]
        # block ordering e.g. (x_dh, y_dh, ..., x_dd, y_dd, ...)
        devices = devices.transpose(
            *(4 * i for i in range(mesh_ndim)),
            *(4 * i + 1 for i in range(mesh_ndim)),
            *(4 * i + 2 for i in range(mesh_ndim)),
            *(4 * i + 3 for i in range(mesh_ndim)),
        )
    else:
        # e.g. [(x_data, x_model), (y_data, y_model), ...]
        model_data_tups = [(g // m, m) for g, m in zip(global_hardware_mesh, model_parallel_submesh)]
        # reshape to e.g. (x_data, x_model, y_data, y_model...)
        devices = devices.reshape(*(s for t in model_data_tups for s in t))  # pylint: disable=g-complex-comprehension
        # TODO(jekbradbury): reorder small subgroups for ring locality
        # transpose to e.g. (x_data, y_data, ..., x_model, ...)
        devices = devices.transpose(*(2 * i for i in range(mesh_ndim)), *(2 * i + 1 for i in range(mesh_ndim)))
    # reshape to (data, model)
    devices = devices.reshape(-1, np.prod(model_parallel_submesh))
    global_mesh = Mesh(devices, ["data", "model"])
    logging.info("global_mesh axis_names: %s", global_mesh.axis_names)
    logging.info("global_mesh devices: %s", global_mesh.devices)
    logging.info("global_mesh devices shape: %s", global_mesh.devices.shape)
    return global_mesh
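

# Example (hypothetical illustration): on a TPU v3 df_4x4 slice, whose hardware
# bounds are (4, 4, 1, 2), a model-parallel submesh of (2, 2, 1, 2) gives an
# 8-way 'model' axis and the remaining 4-way 'data' axis:
#
#   mesh = get_mesh((2, 2, 1, 2))
#   # mesh.devices.shape == (4, 8); mesh.axis_names == ('data', 'model')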


def get_cpu_mesh() -> Mesh:
    """Trivial mesh for CPU testing."""
    devices = np.empty((jax.host_count(), jax.local_device_count()), dtype=object)
    for device in jax.devices():
        devices[device.process_index, device.id % jax.local_device_count()] = device
    return Mesh(devices, ["data", "model"])


def get_gpu_mesh(num_partitions: int) -> Mesh:
    """Mesh for GPUs that preferentially places 'model' on NVLink."""
    nvlink_size = jax.local_device_count()
    dcn_size = jax.process_count()
    nvlink_mp = min(num_partitions, nvlink_size)
    nvlink_dp, extra1 = divmod(nvlink_size, nvlink_mp)
    dcn_mp, extra2 = divmod(num_partitions, nvlink_mp)
    assert not (extra1 or extra2), (
        "number of partitions on GPU must be a factor or multiple of the number of local devices"
    )
    dcn_dp = dcn_size // dcn_mp

    devices = create_hybrid_device_mesh(
        mesh_shape=[nvlink_dp, nvlink_mp], dcn_mesh_shape=[dcn_dp, dcn_mp], process_is_granule=True
    )

    global_mesh = Mesh(devices, ["data", "model"])
    logging.info("global_mesh axis_names: %s", global_mesh.axis_names)
    logging.info("global_mesh devices: %s", global_mesh.devices)
    return global_mesh


def default_mesh(
    num_partitions: int, model_parallel_submesh: Optional[HardwareMesh] = None, backend: Optional[str] = None
) -> Mesh:
    """Attempts to return a default mesh for simple cases.

    Args:
        num_partitions: number of partitions to use; ignored if
            model_parallel_submesh is provided.
        model_parallel_submesh: 4-tuple that specifies the x,y,z,c submesh to use
            as the model-parallel device tile.
        backend: get devices from the pinned backend, if specified. This is useful
            for explicitly specifying the devices rather than relying on
            jax_platform_name.

    Returns:
        An xmap/pjit 2D Mesh with 'data', 'model' mesh axes.
    """
    last_device = jax.devices(backend)[-1]
    platform = last_device.platform
    device_kind = last_device.device_kind
    bounds = bounds_from_last_device(last_device)

    if model_parallel_submesh:
        return get_mesh(model_parallel_submesh, backend=backend)

    if platform == "cpu":
        return get_cpu_mesh()
    elif platform == "gpu":
        return get_gpu_mesh(num_partitions)

    mps = None
    if device_kind in ("TPU v2", "TPU v3"):
        if num_partitions == 1:
            mps = (1, 1, 1, 1)
        elif num_partitions == 2:
            mps = (1, 1, 1, 2)
        elif num_partitions == 4:
            mps = (2, 1, 1, 2)
        elif num_partitions == 8:
            mps = (2, 2, 1, 2)
        elif num_partitions == 16:
            mps = (4, 2, 1, 2)
    # assume the use of megacore on TPU v4
    elif (device_kind == "TPU v4" or device_kind == "TPU v4 lite") and bounds[3] == 1:
        if num_partitions == 1:
            mps = (1, 1, 1, 1)
        elif num_partitions == 2:
            mps = (1, 2, 1, 1)
        elif num_partitions == 4:
            if bounds[0] >= 4:
                mps = (4, 1, 1, 1)
            else:
                mps = (2, 2, 1, 1)
        elif num_partitions == 8:
            if bounds[2] >= 8:
                mps = (1, 1, 8, 1)
            else:
                mps = (4, 2, 1, 1)
        elif num_partitions == 16:
            if bounds[2] >= 16:
                mps = (1, 1, 16, 1)
            elif bounds[0] >= 8:
                mps = (8, 2, 1, 1)
            elif bounds[0] >= 4:
                mps = (4, 4, 1, 1)
            else:
                mps = (2, 2, 4, 1)

    if mps is None:
        raise ValueError("No default mesh for this configuration: specify config.model_parallel_submesh explicitly.")

    return get_mesh(mps, backend=backend)
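

# Example (hypothetical): on a TPU v3 host, `default_mesh(num_partitions=8)`
# selects the (2, 2, 1, 2) model-parallel submesh from the table above and
# delegates to `get_mesh`; on CPU it simply returns `get_cpu_mesh()`.
#
#   mesh = default_mesh(num_partitions=8)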


# Data chunking helper.
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class LocalChunkInfo:
    # The logical slice of an array located on this host's local devices.
    slice: Tuple[slice, ...]
    # A unique index for this host/local chunk among chunks with the same slice.
    replica_id: int


class LocalChunker:
    """Utility class to aid chunking of sharded arrays in multihost settings."""

    def __init__(self, global_mesh: Mesh):
        self.global_mesh = global_mesh
        local_mesh = global_mesh.local_mesh
        first_local_device = local_mesh.devices.reshape(-1)[0]
        host_location = collections.OrderedDict(
            zip(global_mesh.shape.keys(), list(zip(*np.nonzero(global_mesh.devices == first_local_device)))[0])
        )
        self.num_chunks = collections.OrderedDict()
        self.chunk_ids = collections.OrderedDict()
        self.mesh_axes = list(global_mesh.shape.keys())
        for mesh_axis in self.mesh_axes:
            num_devices_per_chunk = local_mesh.shape[mesh_axis]
            self.num_chunks[mesh_axis] = global_mesh.shape[mesh_axis] // num_devices_per_chunk
            self.chunk_ids[mesh_axis] = host_location[mesh_axis] // num_devices_per_chunk

    def get_local_chunk_info(
        self, global_shape: Tuple[int, ...], mesh_axes: Sequence[Optional[str]]
    ) -> LocalChunkInfo:
        """Gets the local chunk info for a given array shape and sharded axes.

        Args:
            global_shape: the global, unsharded shape of the array to chunk.
            mesh_axes: a sequence of names (or None) of equal rank to `global_shape`
                that specifies which mesh dimensions the array is sharded along.

        Returns:
            LocalChunkInfo containing the logical slices of the array found on this
            host's local devices, as well as the replica index for this chunk among
            chunks with the same slice. The latter is used to determine which host
            should write this chunk during checkpointing.
        """
        local_slice = [slice(None) for dim in global_shape]
        sharded_mesh_axes = set()
        for i, (mesh_axis, size) in enumerate(zip(mesh_axes, global_shape)):
            if not mesh_axis:
                continue
            sharded_mesh_axes.add(mesh_axis)
            if not isinstance(mesh_axis, str):
                raise NotImplementedError("TODO(jekbradbury)")
            chunk_id = self.chunk_ids[mesh_axis]
            chunk_size = size // self.num_chunks[mesh_axis]
            local_slice[i] = slice(chunk_id * chunk_size, (chunk_id + 1) * chunk_size)

        replicated_mesh_axes = [mesh_axis for mesh_axis in self.mesh_axes if mesh_axis not in sharded_mesh_axes]
        replica_id = 0
        for mesh_axis in replicated_mesh_axes:
            chunk_id = self.chunk_ids[mesh_axis]
            replica_id = replica_id * self.num_chunks[mesh_axis] + chunk_id

        return LocalChunkInfo(tuple(local_slice), replica_id)
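

# Example (hypothetical illustration): for a (1024, 512) array sharded as
# ('data', None) on a mesh with 8 data-parallel chunks, the host whose 'data'
# chunk id is 3 would see
#
#   chunker = LocalChunker(global_mesh)
#   info = chunker.get_local_chunk_info((1024, 512), ("data", None))
#   # info.slice == (slice(384, 512), slice(None)); info.replica_id identifies
#   # this host among hosts holding the same slice.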


def standard_logical_axis_rules(
    activation_partitioning_dims: int = 1,
    parameter_partitioning_dims: int = 1,
    additional_rules: Optional[LogicalAxisRules] = None,
) -> LogicalAxisRules:
    """Default sharding rules for T5X model in terms of logical axis names.

    Args:
        activation_partitioning_dims: enables 2-D activation sharding when set to 2.
        parameter_partitioning_dims: enables 2-D parameter sharding when set to 2.
        additional_rules: additional rules (a sequence of tuples) that will be
            appended to the standard rules.

    Returns:
        Sequence of logical axis rules
    """
    logging.info(
        "`activation_partitioning_dims` = %d, `parameter_partitioning_dims` = %d",
        activation_partitioning_dims,
        parameter_partitioning_dims,
    )

    if activation_partitioning_dims == 1 and parameter_partitioning_dims == 1:
        rules = [
            ("batch", "data"),
            ("vocab", "model"),
            ("embed", None),
            ("mlp", "model"),
            ("heads", "model"),
            ("kv", None),
            ("joined_kv", "model"),  # joined heads+kv dim in 2D attn param layouts
        ]
    elif activation_partitioning_dims == 2 and parameter_partitioning_dims == 1:
        rules = [
            ("batch", "data"),
            ("vocab", "model"),
            ("mlp", "model"),
            ("heads", "model"),
            ("kv", None),
            ("joined_kv", "model"),
            ("embed", "model"),
        ]
    elif activation_partitioning_dims == 1 and parameter_partitioning_dims == 2:
        rules = [
            ("batch", "data"),
            ("vocab", "model"),
            ("mlp", "model"),
            ("heads", "model"),
            ("kv", None),
            ("joined_kv", "model"),
            ("embed", "data"),
        ]
    elif activation_partitioning_dims == 2 and parameter_partitioning_dims == 2:
        rules = [
            ("batch", "data"),
            ("vocab", "model"),
            ("mlp", "model"),
            ("heads", "model"),
            ("kv", None),
            ("joined_kv", "model"),
            ("embed", "model"),
            ("embed", "data"),
        ]
    else:
        raise ValueError(
            f"`activation_partitioning_dims` = {activation_partitioning_dims} "
            f"`parameter_partitioning_dims` = {parameter_partitioning_dims} "
            "is not supported."
        )

    # Add the common rules for the replicated logical axes names.
    replicated_rules = [
        ("relpos_buckets", None),
        ("abspos_buckets", None),
        ("length", None),
        ("layers", None),
        ("stack", None),
        ("mlp_activations", None),
    ]
    rules.extend(replicated_rules)

    if additional_rules:
        rules.extend(additional_rules)

    return rules
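

# Example (hypothetical illustration): with the default 1-D rules above,
# `flax_partitioning.logical_to_mesh_axes` resolves logical axis names to mesh
# axes, e.g. a ('batch', 'embed') activation maps to PartitionSpec('data', None):
#
#   rules = standard_logical_axis_rules()
#   flax_partitioning.logical_to_mesh_axes(("batch", "embed"), rules)
#   # -> PartitionSpec('data', None)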


# NB: This needs to be top-level for the jax compilation cache.
def _id_fn(x, ix):
    """Identity function for copying parameters to the devices, sharded."""
    # A pure identity such as `lambda x, *: x` can get optimized away, so we
    # include a random.split as a cheap function that cannot be optimized away.
    y = random.split(random.PRNGKey(jnp.array(ix, dtype=jnp.uint32)))
    return x, y


@dataclasses.dataclass
class DataLayout:
    """Represents data layout for the partitioned model."""

    batch_size: int
    shard_id: int
    num_shards: int
    is_first_host_in_replica_set: bool


PartitionedCallable = Callable[..., Any]
CompiledPartitionedCallable = Callable[..., Any]


class BasePartitioner(metaclass=abc.ABCMeta):
    """Interface for partitioning computations across hardware devices."""

    def __init__(
        self,
        num_partitions: Optional[int] = None,
        model_parallel_submesh: Optional[HardwareMesh] = None,
        params_on_devices: bool = True,
        backend: Optional[str] = None,
    ):
        """Configures the partitioner.

        Args:
            num_partitions: the number of partitions to use. Ignored if
                `model_parallel_submesh` is provided.
            model_parallel_submesh: 4-tuple that specifies the x,y,z,c submesh to
                use as the model-parallel device tile. This submesh is used for the
                larger of the two parameter dimensions, and, if 2-D activation
                sharding is enabled, for the model dimension of activations. The
                rest of the mesh is used for data parallelism and, if 2-D parameter
                sharding is enabled, the other parameter dimension.
            params_on_devices: whether to keep the params on devices; if False,
                params stay in host memory. Note that some partitioners might ignore
                this setting, for example if they don't support storing all params
                in device memory.
            backend: get devices from the pinned backend, if specified. This is
                useful for explicitly specifying the devices rather than relying on
                jax_platform_name.
        """
        if not num_partitions and not model_parallel_submesh:
            raise ValueError("At least one of `num_partitions` or `model_parallel_submesh` must be set.")

        if model_parallel_submesh is not None and len(model_parallel_submesh) != 4:
            logging.error(
                (
                    "`model_parallel_submesh` must be either None or a 4-tuple. Got"
                    " `model_parallel_submesh`=%s. A ValueError will be raised"
                    " beginning March 1, 2022."
                ),
                model_parallel_submesh,
            )

        if bool(num_partitions) and bool(model_parallel_submesh):
            logging.error(
                "At most one of `num_partitions` or `model_parallel_submesh` can be "
                "set. Got `num_partitions=%s` and `model_parallel_submesh`=%s. A "
                "ValueError will be raised beginning March 21, 2022.",
                num_partitions,
                model_parallel_submesh,
            )

        self._num_partitions = num_partitions
        self._model_parallel_submesh = model_parallel_submesh
        self._params_on_devices = params_on_devices
        self._data_axis = "data"
        self._backend = backend

    @property
    def mesh(self) -> Mesh:
        raise NotImplementedError

    @property
    def data_partition_spec(self) -> PartitionSpec:
        return PartitionSpec(self._data_axis)
    def get_data_layout(self, batch_size: Optional[int] = None, host_index: Optional[int] = None) -> DataLayout:
        """Returns a filled `DataLayout` based on the partitioned model layout.

        Args:
            batch_size: if set, indicates the requested batch size. An exception is
                raised if this batch size is not compatible with the layout. If not
                set, the batch size is inferred from the layout.
            host_index: indicates the host index to use for the calculations; if not
                set, the JAX-provided one is used. Should be in the [0, num_hosts)
                interval, and the order should match the order of the corresponding
                CPU devices in `jax.devices()`.

        Returns:
            Filled `DataLayout` structure.
        """
        if host_index is not None:
            raise NotImplementedError("Explicit host_index is not yet implemented.")
        if self._data_axis is None:
            return DataLayout(
                batch_size=batch_size,
                shard_id=0,
                num_shards=1,
                is_first_host_in_replica_set=(jax.process_index() == 0),
            )
        mesh_size = self._local_chunker.global_mesh.shape[self._data_axis]
        batch_size = batch_size or mesh_size
        if batch_size % mesh_size:
            raise ValueError(
                f"Batch size ({batch_size}) must be divisible by the corresponding mesh size ({mesh_size})."
            )
        num_shards = self._local_chunker.num_chunks[self._data_axis]
        if batch_size % num_shards:
            raise ValueError(f"Batch size ({batch_size}) must be divisible by the number of replicas ({num_shards}).")
        replica_id = self._local_chunker.get_local_chunk_info((batch_size,), [self._data_axis]).replica_id
        return DataLayout(
            batch_size=int(batch_size),
            shard_id=int(self._local_chunker.chunk_ids[self._data_axis]),
            num_shards=int(num_shards),
            is_first_host_in_replica_set=(replica_id == 0),
        )

    def get_local_chunk_info(
        self, global_shape: Tuple[int, ...], mesh_axes: Sequence[Optional[str]]
    ) -> LocalChunkInfo:
        """Returns the local chunk info for a given array shape and sharded axes."""
        return self._local_chunker.get_local_chunk_info(global_shape, mesh_axes)

    @property
    def params_on_devices(self):
        return self._params_on_devices
    def move_params_to_devices(self, train_state: TrainState, train_state_axes: TrainState) -> TrainState:
        """Moves the optimizer parameters to devices."""
        p_id_fn = self.partition(
            _id_fn,
            in_axis_resources=(train_state_axes, None),
            out_axis_resources=(train_state_axes, None),
            donate_argnums=(0,),
        )
        if jax.config.jax_array and jax.process_count() > 1:
            train_state = multihost_utils.host_local_array_to_global_array(train_state, self.mesh, train_state_axes)
        train_state, _ = p_id_fn(train_state, jnp.ones((), dtype=jnp.uint32))
        return train_state

    @property
    def _local_chunker(self):
        """Returns the chunker that matches the parameters of this partitioner."""
        raise NotImplementedError

    def get_logical_axes(self, train_state: TrainState) -> TrainState:
        """Returns a copy of TrainState with Optional[AxisNames] as leaves."""
        # By default, return None for the logical axes.
        return train_state.restore_state(jax.tree_map(lambda x: None, train_state.state_dict()))

    def get_mesh_axes(self, train_state: TrainState) -> TrainState:
        """Returns a copy of TrainState with Optional[PartitionSpecs] as leaves."""
        raise NotImplementedError
    def partition(
        self,
        fn: Callable,  # pylint: disable=g-bare-generic
        in_axis_resources,
        out_axis_resources,
        static_argnums: Union[int, Sequence[int]] = (),
        donate_argnums: Union[int, Sequence[int]] = (),
    ) -> PartitionedCallable:
        """Partitions the computation using a partitioner-specific implementation.

        Args:
            fn: the function to partition.
            in_axis_resources: Pytree of structure matching that of arguments to
                `fn`, with all actual arguments replaced by resource assignment
                specifications. It is also valid to specify a pytree prefix (e.g.
                one value in place of a whole subtree), in which case the leaves get
                broadcast to all values in that subtree.
                The valid resource assignment specifications are:
                    `None`: in which case the value will be replicated on all
                        devices.
                    `PartitionSpec`: a tuple of length at most equal to the rank of
                        the partitioned value. Each element can be a `None`, a mesh
                        axis, or a tuple of mesh axes, and specifies the set of
                        resources assigned to partition the value's dimension
                        matching its position in the spec.
            out_axis_resources: Like `in_axis_resources`, but specifies resource
                assignment for function outputs.
            static_argnums: an optional int or collection of ints that specify which
                positional arguments to treat as static (compile-time constant) in
                the partitioned function.
            donate_argnums: an optional int or collection of ints that specify which
                argument buffers are "donated" to the computation. It is safe to
                donate argument buffers if you no longer need them once the
                computation has finished.

        Returns:
            A partitioned version of the input function.
        """
        raise NotImplementedError

    def compile(self, partitioned_fn: PartitionedCallable, *args) -> CompiledPartitionedCallable:
        """Compiles and returns the partitioned function, or the original.

        Args:
            partitioned_fn: The partitioned function.
            *args: Sample arguments to the partitioned function matching the input
                shapes that will be passed to the compiled function.

        Returns:
            The compiled function, or the original if this partitioner does not
            support compilation.
        """
        raise NotImplementedError


class PjittedFnWithContext(PartitionedCallable):
    """Wraps a pjitted function to apply the appropriate contexts."""

    def __init__(self, pjitted_fn, partition_mesh: Mesh, logical_axis_rules: flax_partitioning.LogicalRules = ()):
        self._pjitted_fn = pjitted_fn
        self._mesh = partition_mesh
        self._logical_axis_rules = logical_axis_rules

    def __call__(self, *args):
        with Mesh(self._mesh.devices, self._mesh.axis_names), flax_partitioning.axis_rules(self._logical_axis_rules):
            return self._pjitted_fn(*args)

    def lower(self, *args):
        with Mesh(self._mesh.devices, self._mesh.axis_names), flax_partitioning.axis_rules(self._logical_axis_rules):
            return self._pjitted_fn.lower(*args)


class BasePjitPartitioner(BasePartitioner):
    """Partitioner that uses the T5X version of jax.pjit."""

    @cached_property
    def _local_chunker(self) -> LocalChunker:
        return LocalChunker(self.mesh)

    @cached_property
    def mesh(self) -> Mesh:
        return default_mesh(self._num_partitions, self._model_parallel_submesh, self._backend)

    def partition(
        self,
        fn: Callable,  # pylint: disable=g-bare-generic
        in_axis_resources,
        out_axis_resources,
        static_argnums: Union[int, Sequence[int]] = (),
        donate_argnums: Union[int, Sequence[int]] = (),
    ) -> PjittedFnWithContext:
        pjitted = pjit(
            fn,
            in_axis_resources=in_axis_resources,
            out_axis_resources=out_axis_resources,
            static_argnums=static_argnums,
            donate_argnums=donate_argnums,
            backend=self._backend,
        )

        return PjittedFnWithContext(pjitted, self.mesh)

    def compile(self, partitioned_fn: PjittedFnWithContext, *args) -> CompiledPartitionedCallable:
        return partitioned_fn.lower(*args).compile()


class PjitPartitioner(BasePjitPartitioner):
    """Partitioner that uses named axes and jax.pjit."""

    def __init__(
        self,
        num_partitions: Optional[int] = None,
        model_parallel_submesh: Optional[HardwareMesh] = None,
        params_on_devices: bool = True,
        backend: Optional[str] = None,
        logical_axis_rules: Optional[LogicalAxisRules] = None,
        use_cpu_pjit: Optional[bool] = False,
    ):
        """PjitPartitioner constructor.

        See https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx/usage/partitioning for details.

        Args:
            num_partitions: an integer that specifies the size of the model parallel
                submesh to be automatically selected for the current topology. See
                `model_parallel_submesh` for details on how this submesh is used.
                Mutually exclusive with `model_parallel_submesh`.
            model_parallel_submesh: a 4-tuple that specifies the `(x, y, z, c)`
                submesh model-parallel device tile, an axis of accelerator
                parallelism orthogonal to data parallelism. Array axes in a model's
                parameters or activations can be sharded over this submesh using
                axis rules (see `logical_axis_rules`) that map them to 'model'. The
                effective number of model sub-partitions is equal to
                `np.prod(model_parallel_submesh)` and must evenly divide the total
                number of devices (i.e.,
                `jax.device_count() % np.prod(model_parallel_submesh) == 0`). The
                rest of the TPU mesh is the data parallel submesh, providing
                `jax.device_count() // np.prod(model_parallel_submesh)` partitions.
                It is used for data (batch) parallelism and to shard other array
                axes that are mapped to 'data'. This argument is mutually exclusive
                with `num_partitions`.
            params_on_devices: whether to keep the params on devices; if False,
                params stay in host memory. Note that some partitioners might ignore
                this setting, for example if they don't support storing all params
                in device memory.
            backend: get devices from the pinned backend, if specified. This is
                useful for explicitly specifying the devices rather than relying on
                jax_platform_name.
            logical_axis_rules: a priority-ordered sequence of KV tuples that maps
                logical axis names to either `None` (not sharded), 'model' (to shard
                across the model-parallel submesh), or 'data' (to shard across the
                data-parallel submesh).
            use_cpu_pjit: enables a wrapper for pjit which simply jits the function
                when using the CPU backend.
        """
        super().__init__(
            num_partitions=num_partitions,
            model_parallel_submesh=model_parallel_submesh,
            params_on_devices=params_on_devices,
            backend=backend,
        )
        if logical_axis_rules is None:
            logical_axis_rules = standard_logical_axis_rules()
        self._logical_axis_rules = tuple(logical_axis_rules)
        (self._data_axis,) = flax_partitioning.logical_to_mesh_axes(["batch"], logical_axis_rules)
        self._use_cpu_pjit = use_cpu_pjit
    def partition(
        self,
        fn: Callable,  # pylint: disable=g-bare-generic
        in_axis_resources,
        out_axis_resources,
        static_argnums: Union[int, Sequence[int]] = (),
        donate_argnums: Union[int, Sequence[int]] = (),
    ) -> PjittedFnWithContext:
        """Partitions the function using jax.pjit."""
        if self._use_cpu_pjit:
            pjit_fn = pjit_with_cpu_fallback
        else:
            pjit_fn = pjit
        pjitted = pjit_fn(
            fn,
            in_axis_resources=in_axis_resources,
            out_axis_resources=out_axis_resources,
            static_argnums=static_argnums,
            donate_argnums=donate_argnums,
            backend=self._backend,
        )

        return PjittedFnWithContext(pjitted, self.mesh, self._logical_axis_rules)

    @property
    def logical_axis_rules(self):
        """Returns the logical axis rules."""
        return self._logical_axis_rules
    def get_logical_axes(self, train_state: TrainState) -> TrainState:
        """Returns a copy of TrainState with Optional[AxisNames] as leaves."""
        return train_state.as_logical_axes()

    def get_mesh_axes(self, train_state: TrainState) -> TrainState:
        """Returns a copy of TrainState with Optional[PartitionSpecs] as leaves."""
        logical_axes = self.get_logical_axes(train_state)

        def _logical_to_mesh_axes(param_name, logical_axes):
            if logical_axes is None:
                return None
            elif logical_axes is traverse_util.empty_node:
                return traverse_util.empty_node
            try:
                return flax_partitioning.logical_to_mesh_axes(logical_axes, self._logical_axis_rules)
            except ValueError as e:
                raise ValueError(f"Failed to map logical axes for {param_name}") from e

        flat_logical_axes = traverse_util.flatten_dict(logical_axes.state_dict(), keep_empty_nodes=True, sep="/")
        flat_mesh_axes = {k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()}

        return logical_axes.restore_state(traverse_util.unflatten_dict(flat_mesh_axes, sep="/"))
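

# Example (hypothetical end-to-end sketch): construct a partitioner, derive the
# mesh axes for a train state, and partition a step function. `train_state` and
# `train_step` are assumed to be defined elsewhere.
#
#   partitioner = PjitPartitioner(num_partitions=1)
#   train_state_axes = partitioner.get_mesh_axes(train_state)
#   p_train_step = partitioner.partition(
#       train_step,
#       in_axis_resources=(train_state_axes, PartitionSpec("data")),
#       out_axis_resources=(train_state_axes, None),
#   )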