import torch.distributed as dist

# Process-group registry, keyed by parallelism kind ("data", "sequence").
_GLOBAL_PARALLEL_GROUPS = dict()

def set_data_parallel_group(group: dist.ProcessGroup):
    """Register the process group used for data parallelism."""
    _GLOBAL_PARALLEL_GROUPS["data"] = group

def get_data_parallel_group():
    """Return the data-parallel group, or None if unset."""
    return _GLOBAL_PARALLEL_GROUPS.get("data", None)

def set_sequence_parallel_group(group: dist.ProcessGroup):
    """Register the process group used for sequence parallelism."""
    _GLOBAL_PARALLEL_GROUPS["sequence"] = group

def get_sequence_parallel_group():
    """Return the sequence-parallel group, or None if unset."""
    return _GLOBAL_PARALLEL_GROUPS.get("sequence", None)
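
# --- Usage sketch (illustrative, not part of the module above) ---
# A minimal sketch of wiring these helpers up after torch.distributed is
# initialized. The 4-rank layout, the "gloo" backend, and the helper name
# init_parallel_groups are assumptions for illustration; adapt to your topology.

def init_parallel_groups():
    # Assumes env:// initialization (RANK/WORLD_SIZE/MASTER_ADDR set by the launcher).
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    # Assumed layout for world_size == 4: ranks {0,1} and {2,3} form the
    # data-parallel groups; ranks {0,2} and {1,3} form the sequence-parallel
    # groups. dist.new_group() is collective and must be called by every rank
    # for every group, so both loops run on all ranks; each rank only stores
    # the group it belongs to.
    for ranks in ([0, 1], [2, 3]):
        group = dist.new_group(ranks)
        if rank in ranks:
            set_data_parallel_group(group)
    for ranks in ([0, 2], [1, 3]):
        group = dist.new_group(ranks)
        if rank in ranks:
            set_sequence_parallel_group(group)

# Collectives can then be scoped to a registered group, e.g.:
#   dist.all_reduce(tensor, group=get_data_parallel_group())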