from functools import partial
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.distributed.tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    DTensor,
    Replicate,
    Shard,
)
from torch.distributed.tensor.parallel import ParallelStyle
from torch.distributed.tensor.placement_types import Placement


class TensorParallel(ParallelStyle):
    """Tensor-parallelizes a module holding stacked 3-D weights ``w1``,
    ``w2``, and ``w3`` (dim 0 being a per-expert stack dimension): ``w1``
    and ``w3`` are sharded on their last dim and ``w2`` on dim 1, so each
    rank owns a column/row slice of every expert's weights.
    """

    def __init__(
        self,
        *,
        input_layouts: Optional[Tuple[Optional[Placement], ...]] = None,
        output_layout: Optional[Placement] = None,
        use_local_output: bool = True,
    ):
        super().__init__()
        self.input_layouts = input_layouts or (Replicate(), Replicate())
        self.output_layout = output_layout or Replicate()
        # Inputs are made fully replicated before entering the module.
        self.desired_input_layouts = (Replicate(), Replicate())
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(
        input_layouts, desired_input_layouts, mod, inputs, device_mesh
    ):
        prepared_inputs = []
        # Wrap plain tensors as DTensors with the user-declared layout and
        # redistribute them to the layout the sharded computation expects;
        # non-tensor inputs pass through unchanged.
        for inp, input_layout, desired_input_layout in zip(
            inputs, input_layouts, desired_input_layouts
        ):
            if isinstance(inp, torch.Tensor):
                if not isinstance(inp, DTensor):
                    inp = DTensor.from_local(
                        inp, device_mesh, (input_layout,), run_check=False
                    )
                if input_layout != desired_input_layout:
                    inp = inp.redistribute(
                        placements=(desired_input_layout,), async_op=True
                    )
            prepared_inputs.append(inp)
        return tuple(prepared_inputs)

    def _partition_fn(self, name, module, device_mesh):
        # Shard the stacked weights: w1/w3 column-wise on their last dim,
        # w2 row-wise on dim 1. All three are assumed 3-D, with dim 0 as
        # the expert/stack dimension.
        module.register_parameter(
            "w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(2)]))
        )
        module.register_parameter(
            "w2", nn.Parameter(distribute_tensor(module.w2, device_mesh, [Shard(1)]))
        )
        module.register_parameter(
            "w3", nn.Parameter(distribute_tensor(module.w3, device_mesh, [Shard(2)]))
        )

    @staticmethod
    def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
        # Convert the module output to the caller-requested layout, then
        # optionally unwrap the DTensor into a plain local tensor.
        if outputs.placements != (output_layout,):
            outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            self._partition_fn,
            partial(
                self._prepare_input_fn, self.input_layouts, self.desired_input_layouts
            ),
            partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
        )
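

# Illustrative only: a minimal sketch (an assumption, not part of this
# module) of the parameter shapes ``TensorParallel._partition_fn`` expects.
# With weights stacked per expert along dim 0, Shard(2) on w1/w3 and
# Shard(1) on w2 both split the FFN hidden dimension across the mesh.
class _GroupedExpertsSketch(nn.Module):
    """Hypothetical grouped-experts FFN, used only to document shapes."""

    def __init__(self, num_experts: int, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
        self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU-style FFN batched over the expert dimension via bmm;
        # x is (num_experts, tokens_per_expert, dim).
        h = torch.nn.functional.silu(torch.bmm(x, self.w1)) * torch.bmm(x, self.w3)
        return torch.bmm(h, self.w2)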


class NoParallel(ParallelStyle):
    """Leaves the module's parameters unsharded; only annotates the first
    input as a DTensor (replicated by default) and converts the output
    layout, so the module composes with tensor-parallel neighbors.
    """

    def __init__(
        self,
        *,
        input_layout: Optional[Placement] = None,
        output_layout: Optional[Placement] = None,
        use_local_output: bool = True,
    ):
        super().__init__()
        self.input_layout = input_layout or Replicate()
        self.output_layout = output_layout or Replicate()
        self.desired_input_layout = Replicate()
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh):
        # Wrap the first input as a DTensor and replicate it if needed;
        # the remaining inputs are passed through untouched.
        input_tensor = inputs[0]
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(
                input_tensor, device_mesh, (input_layout,), run_check=False
            )
        if input_layout != desired_input_layout:
            input_tensor = input_tensor.redistribute(
                placements=(desired_input_layout,), async_op=True
            )
        return (input_tensor, *inputs[1:])

    @staticmethod
    def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
        # Same output handling as TensorParallel: convert to the requested
        # layout, then optionally return the local tensor.
        if outputs.placements != (output_layout,):
            outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        # No partition_fn: parameters stay as-is; only the input/output
        # hooks are installed.
        return distribute_module(
            module,
            device_mesh,
            None,
            partial(
                self._prepare_input_fn, self.input_layout, self.desired_input_layout
            ),
            partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
        )
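

# Usage sketch (an assumption, not part of this module): both styles plug
# into ``parallelize_module`` like any other ``ParallelStyle``. The
# submodule names "experts" and "norm" and the mesh size are illustrative;
# a default process group is assumed to be initialized (e.g. via torchrun).
def _example_parallelize(model: nn.Module) -> nn.Module:
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import parallelize_module

    # One-dimensional tensor-parallel mesh over all visible GPUs.
    tp_mesh = init_device_mesh("cuda", (torch.cuda.device_count(),))
    return parallelize_module(
        model,
        tp_mesh,
        {
            "experts": TensorParallel(),
            "norm": NoParallel(),
        },
    )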