Implement AD functionality for MPI operations
Feature functionality
It is planned to implement a thin provisioning layer on top of the existing MPI communication API that allows for nearly seamless integration of the MPI calls with the PyTorch AD system.
Additional context
The rationale for implementing AD is that it is one of the features that distinguishes an ML library from a plain linear algebra library.
There are several issues that need to be resolved in order to provide an AD-capable MPI-like API. The PyTorch AD machinery, and AD in general, makes strong assumptions about the functions one can compute a derivative of.
- [ ] The AD mechanism expects functions to be mostly pure functions: in particular, all variable input shall be passed as arguments and all output shall be returned from the function. MPI function calls rather often have `recvbuf` arguments that violate the latter part, and `MPI.IN_PLACE` functionality may also be difficult to support (see the sketch below this list).
- [ ] In PyTorch's AD implementation, the AD functionality is achieved using a directed acyclic graph (DAG), within which every edge represents a PyTorch tensor that is exchanged. It needs to be a tensor and cannot be a general data type, since a bifurcation in forward mode becomes a summation in backward mode, thus requiring the exchanged object to belong to some group with a reasonable addition defined on it. This is, of course, not the case for e.g. `MPIRequest` instances. As such, to make all asynchronous MPI calls AD-capable, and to properly reflect the causal dependency between an async MPI call and the corresponding Wait call in the DAG, one probably needs to encapsulate the `MPIRequest`s in a `torch.tensor`, as has been done in the prototype.
- [ ] Other synchronization issues are also to be expected, since the DAG loses the ordering of the original source code in which the MPI calls were executed. As such, additional synchronization primitives are necessary to reflect the temporal ordering of calls in the DAG, thus avoiding potential deadlocks and race conditions in the backward pass.
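To illustrate the purity point of the first item: a pure-function-style wrapper would allocate the receive buffer internally and return it, instead of writing into a caller-provided `recvbuf`. A minimal sketch with mpi4py follows; the name `allreduce_sum_pure` is purely illustrative and not an existing API.

```python
# Minimal sketch, assuming mpi4py; `allreduce_sum_pure` is an illustrative name.
import numpy as np
import torch
from mpi4py import MPI


def allreduce_sum_pure(x: torch.Tensor, comm: MPI.Comm = MPI.COMM_WORLD) -> torch.Tensor:
    """Pure-function-style Allreduce: input passed as argument, output returned."""
    sendbuf = x.detach().cpu().contiguous().numpy()
    recvbuf = np.empty_like(sendbuf)              # output buffer allocated internally ...
    comm.Allreduce(sendbuf, recvbuf, op=MPI.SUM)  # ... instead of mutating a caller buffer
    return torch.from_numpy(recvbuf).to(x.device)
```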
Introducing "student project" label for potential thesis work.
Reviewed within #1109
Suggestion for a "prototype"
A good start could be to implement something like `diffable_Allreduce_SUM`, i.e. an AD-compatible counterpart of `ht.comm.Allreduce( ... , op=MPI.SUM)`. This is maybe the simplest case, since (at least in my understanding) the structure of this function does not cause problems when reversing the direction of the DAG that would need to be caught with dummy constructions etc.
I would suggest trying to define a custom autograd-able function `diffable_Allreduce_SUM` with corresponding `forward()`, `backward()`, and `setup_context()` (?) as described here:
https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html and in particular: https://pytorch.org/docs/stable/notes/extending.html
Regarding the derivatives (i.e. `backward()`) for `Allreduce` (with `MPI.SUM`), see the great work of @d1saster: https://github.com/helmholtz-analytics/mpi4torch/blob/master/csrc/extension.cpp. In fact, my suggestion for this issue is to proceed in Python similarly to what has been done in C++ for `mpi4torch`.
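As a concrete starting point, here is a minimal sketch of such a custom autograd function, written directly against mpi4py and `torch.autograd.Function` rather than Heat's `ht.comm` wrapper; the names `DiffableAllreduceSum` and `diffable_allreduce_sum` are mine and not existing Heat API. The backward pass uses the fact that the adjoint of a sum-Allreduce is again a sum-Allreduce applied to the incoming gradient.

```python
# Minimal sketch, assuming mpi4py; class/function names are illustrative only.
import numpy as np
import torch
from mpi4py import MPI


class DiffableAllreduceSum(torch.autograd.Function):
    """AD-compatible Allreduce with op=MPI.SUM.

    Forward:  y = sum over all ranks of x          (every rank receives y)
    Backward: dL/dx = sum over all ranks of dL/dy  (again an Allreduce-SUM)
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, comm: MPI.Comm) -> torch.Tensor:
        ctx.comm = comm  # keep the communicator for the backward pass
        sendbuf = x.detach().cpu().contiguous().numpy()
        recvbuf = np.empty_like(sendbuf)
        comm.Allreduce(sendbuf, recvbuf, op=MPI.SUM)
        return torch.from_numpy(recvbuf).to(x.device)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # The adjoint of Allreduce-SUM is Allreduce-SUM on the gradient.
        sendbuf = grad_output.detach().cpu().contiguous().numpy()
        recvbuf = np.empty_like(sendbuf)
        ctx.comm.Allreduce(sendbuf, recvbuf, op=MPI.SUM)
        grad_input = torch.from_numpy(recvbuf).to(grad_output.device)
        return grad_input, None  # no gradient w.r.t. the communicator


def diffable_allreduce_sum(x: torch.Tensor, comm: MPI.Comm = MPI.COMM_WORLD) -> torch.Tensor:
    return DiffableAllreduceSum.apply(x, comm)
```

With a helper like this, a loss computed from the reduced tensor can call `.backward()` as usual and gradients are propagated across ranks, mirroring what `mpi4torch` does in C++.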
Next Steps after Prototype
- `Allgather` and `Alltoall` should not cause problems with the computational graph either (at least in my opinion) and therefore might be handled in a similar way (see the sketch after this list for `Allgather`).
- It could also be nice to try out and think about (because not done so far in `mpi4torch`): `Allreduce` with `MPI.PROD` or `MPI.MAX/MIN` (?)
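To make the `Allgather` case concrete: its adjoint is a sum-reduce-scatter, i.e. the incoming gradient of the gathered tensor is summed over all ranks and each rank keeps only the slice corresponding to its own contribution. A minimal sketch (again against mpi4py, gathering along a new leading axis; `DiffableAllgather` is only an illustrative name):

```python
# Minimal sketch, assuming mpi4py; `DiffableAllgather` is an illustrative name.
import numpy as np
import torch
from mpi4py import MPI


class DiffableAllgather(torch.autograd.Function):
    """AD-compatible Allgather along a new leading axis.

    Forward:  every rank receives the stacked contributions of all ranks.
    Backward: the adjoint is a sum-reduce-scatter: the gradients for slice i
              are summed over all ranks and returned to rank i only.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, comm: MPI.Comm) -> torch.Tensor:
        ctx.comm = comm
        sendbuf = x.detach().cpu().contiguous().numpy()
        recvbuf = np.empty((comm.Get_size(),) + sendbuf.shape, dtype=sendbuf.dtype)
        comm.Allgather(sendbuf, recvbuf)
        return torch.from_numpy(recvbuf).to(x.device)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        sendbuf = grad_output.detach().cpu().contiguous().numpy()
        recvbuf = np.empty(sendbuf.shape[1:], dtype=sendbuf.dtype)
        # Sum the per-rank gradients and keep only the slice belonging to this rank.
        ctx.comm.Reduce_scatter_block(sendbuf, recvbuf, op=MPI.SUM)
        return torch.from_numpy(recvbuf).to(grad_output.device), None
```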
Where it gets tricky
In principle this is extensively discussed in the docs of `mpi4torch` and in the above comment of @d1saster:
- every operation that destroys the structure of the DAG in its backward pass; unfortunately, that is most of the operations...