[transformer] Use `torch.distributed._all_gather_base`
Pros:
`_all_gather_base` performs fewer device-to-device (DtoD) memory copies than `all_gather`. `all_gather` does auxiliary DtoD memory copies in https://github.com/pytorch/pytorch/blob/653892e288b750217dcb7bf4f95ad6c63d3a487d/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L1851-L1863.
Cons:
`_all_gather_base` is marked as experimental: https://github.com/pytorch/pytorch/blob/653892e288b750217dcb7bf4f95ad6c63d3a487d/torch/distributed/distributed_c10d.py#L2109-L2112.
Ref:
- `_all_gather_base` impl: https://github.com/pytorch/pytorch/blob/b447fa3912e69e16b37ec9619324964e74a7078a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L2413
- `all_gather` impl: https://github.com/pytorch/pytorch/blob/653892e288b750217dcb7bf4f95ad6c63d3a487d/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L1811
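For illustration, the trade-off described under Pros and Cons could be sketched roughly as follows. This is not the code in this PR: `gather_along_first_dim` is a hypothetical helper name, and falling back to `all_gather` when the backend is not NCCL or the experimental function is missing is an assumption made for the sketch. It also assumes the input is at least 1-D and has the same shape on every rank.

```python
import torch
import torch.distributed as dist


def gather_along_first_dim(local: torch.Tensor, group=None) -> torch.Tensor:
    """Gather `local` from every rank and concatenate the results along dim 0.

    Hypothetical helper for illustration only.
    """
    world_size = dist.get_world_size(group=group)
    local = local.contiguous()

    # Assumption: only take the experimental path on NCCL, and only if the
    # private function exists in the installed PyTorch.
    use_base = hasattr(dist, "_all_gather_base") and dist.get_backend(group) == "nccl"

    if use_base:
        # One flat output buffer; rank i's rows land at [i * n, (i + 1) * n).
        out_shape = (world_size * local.shape[0],) + tuple(local.shape[1:])
        output = torch.empty(out_shape, dtype=local.dtype, device=local.device)
        dist._all_gather_base(output, local, group=group)
        return output

    # Fallback: gather into a list of per-rank tensors, then concatenate,
    # which costs an extra copy on top of whatever all_gather does internally.
    chunks = [torch.empty_like(local) for _ in range(world_size)]
    dist.all_gather(chunks, local, group=group)
    return torch.cat(chunks, dim=0)
```

Both branches return a tensor of shape `(world_size * n, ...)`, so a caller would not need to know which path ran. The process group must already be initialized before the helper is called.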