|
|
|
|
|
|
|
|
|
import pickle |
|
from collections import OrderedDict |
|
|
|
import torch |
|
from torch import distributed as dist |
|
from torch import nn |
|
|
|
from .dist import _get_global_gloo_group, get_world_size |
|
|
|
# Norm layer types whose buffers (e.g. running statistics held in their
# state_dict) are NOT synchronized automatically across processes; their
# states are collected and averaged explicitly via `all_reduce_norm` below.
ASYNC_NORM = (

    nn.BatchNorm1d,

    nn.BatchNorm2d,

    nn.BatchNorm3d,

    nn.InstanceNorm1d,

    nn.InstanceNorm2d,

    nn.InstanceNorm3d,

)


# Public API of this module (used by `from <module> import *`).
__all__ = [

    "get_async_norm_states",

    "pyobj2tensor",

    "tensor2pyobj",

    "all_reduce",

    "all_reduce_norm",

]
|
|
|
|
|
def get_async_norm_states(module):
    """Collect the state of every ASYNC_NORM layer found in ``module``.

    Walks ``module.named_modules()`` and, for each layer whose type is in
    ``ASYNC_NORM``, records all of its ``state_dict`` entries.

    Args:
        module (nn.Module): model to scan.

    Returns:
        OrderedDict: maps fully-qualified entry name
        (``"<submodule path>.<state key>"``) to the corresponding tensor.
    """
    return OrderedDict(
        (f"{prefix}.{key}", value)
        for prefix, layer in module.named_modules()
        if isinstance(layer, ASYNC_NORM)
        for key, value in layer.state_dict().items()
    )
|
|
|
|
|
def pyobj2tensor(pyobj, device="cuda"):
    """Serialize a picklable python object to a 1-D uint8 tensor.

    Args:
        pyobj: any picklable python object.
        device (str): device on which to place the resulting tensor.

    Returns:
        torch.Tensor: 1-D tensor of dtype ``torch.uint8`` holding the
        pickled byte stream (decode with ``tensor2pyobj``).
    """
    # `torch.ByteStorage.from_buffer` + `torch.ByteTensor(storage)` is a
    # deprecated legacy-storage API; build the tensor directly from the
    # pickled bytes instead (`torch.tensor` copies the data).
    data = bytearray(pickle.dumps(pyobj))
    return torch.tensor(data, dtype=torch.uint8, device=device)
|
|
|
|
|
def tensor2pyobj(tensor):
    """Deserialize a uint8 tensor (as produced by ``pyobj2tensor``) back
    into the original python object.

    Args:
        tensor (torch.Tensor): 1-D uint8 tensor of pickled bytes.

    Returns:
        The unpickled python object.
    """
    raw_bytes = tensor.cpu().numpy().tobytes()
    return pickle.loads(raw_bytes)
|
|
|
|
|
def _get_reduce_op(op_name): |
|
return { |
|
"sum": dist.ReduceOp.SUM, |
|
"mean": dist.ReduceOp.SUM, |
|
}[op_name.lower()] |
|
|
|
|
|
def all_reduce(py_dict, op="sum", group=None):

    """

    Apply an all-reduce over a dict of tensors across distributed processes.

    NOTE: make sure that every py_dict has the same keys and values are in the same shape.



    Args:

        py_dict (dict): dict of tensors to apply all reduce op to.

        op (str): operator, could be "sum" or "mean".

        group: process group used for the world-size check; defaults to the
            global gloo group.



    Returns:

        OrderedDict: reduced tensors, keyed by rank 0's key order.

    """

    world_size = get_world_size()

    # Single-process run: nothing to reduce.
    if world_size == 1:

        return py_dict

    if group is None:

        group = _get_global_gloo_group()

    if dist.get_world_size(group) == 1:

        return py_dict




    # Broadcast rank 0's key list so every rank reduces in the same order.
    # NOTE(review): `group` is not forwarded to broadcast/all_reduce below,
    # so the collectives run on the default process group — confirm intended.
    py_key = list(py_dict.keys())

    py_key_tensor = pyobj2tensor(py_key)

    dist.broadcast(py_key_tensor, src=0)

    py_key = tensor2pyobj(py_key_tensor)



    # Remember per-tensor shapes/sizes so the flattened result can be split
    # back into its original pieces after the reduction.
    tensor_shapes = [py_dict[k].shape for k in py_key]

    tensor_numels = [py_dict[k].numel() for k in py_key]



    # Reduce all values in a single collective call on one flat tensor.
    flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key])

    dist.all_reduce(flatten_tensor, op=_get_reduce_op(op))

    # "mean" is implemented as SUM followed by division (see _get_reduce_op).
    if op == "mean":

        flatten_tensor /= world_size



    # Undo the flattening: split by element count, then restore each shape.
    split_tensors = [

        x.reshape(shape)

        for x, shape in zip(torch.split(flatten_tensor, tensor_numels), tensor_shapes)

    ]

    return OrderedDict({k: v for k, v in zip(py_key, split_tensors)})
|
|
|
|
|
def all_reduce_norm(module):
    """Average the ASYNC_NORM layer statistics of ``module`` across devices.

    Collects the norm-layer state, performs a mean all-reduce over all
    processes, and loads the averaged values back into the module
    (non-strictly, so unrelated entries are left untouched).
    """
    averaged_states = all_reduce(get_async_norm_states(module), op="mean")
    module.load_state_dict(averaged_states, strict=False)
|
|