
Implement AD functionality for MPI operations

d1saster opened this issue 5 years ago · 2 comments

Feature functionality
It is planned to implement a thin provisioning layer on top of the existing MPI communication API that allows for nearly seamless integration of the MPI calls with the PyTorch AD system.

Additional context
The rationale for implementing AD is that it is one of the distinguishing features of an ML library in comparison to a plain linear algebra library.

There are several issues that need to be resolved in order to provide an AD-capable, MPI-like API. The PyTorch AD machinery, or AD itself, makes strong assumptions about the functions one can calculate a derivative of.

  • [ ] The AD mechanism expects functions to be mostly pure: in particular, all variable input should be passed as arguments, and all output should be returned from the function. MPI calls quite often have recvbuf arguments, which violates the latter requirement. MPI.IN_PLACE functionality may also be difficult to support (see the sketch after this list).
  • [ ] In PyTorch's AD implementation, the AD functionality is achieved using a directed acyclic graph (DAG) in which every edge represents a PyTorch tensor that is exchanged. It needs to be a tensor and cannot be a general data type, since a bifurcation in forward mode becomes a summation in backward mode, which requires the exchanged object to belong to a group with a reasonable addition defined on it. This is, e.g., not the case for MPIRequest instances. As such, to make all asynchronous MPI calls AD-capable, and to properly reflect in the DAG the causal dependency between an async MPI call and the corresponding Wait call, one probably needs to encapsulate the MPIRequests in a torch.Tensor, as has been done in the prototype.
  • [ ] Other synchronization issues are also to be expected, since the DAG loses the ordering of the original source code in which the MPI calls were executed. Additional synchronization primitives are therefore necessary to reflect the temporal ordering of the calls in the DAG, thus avoiding potential deadlocks and race conditions in the backward pass.
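As an illustration of the first point, here is a minimal sketch of how an Allreduce could be wrapped so that the result is returned instead of being written into a caller-provided recvbuf. It assumes mpi4py, MPI.COMM_WORLD and contiguous CPU tensors; the helper name pure_allreduce is hypothetical. This only addresses the pure-function shape of the call and is not yet AD-capable by itself.

```python
# Hypothetical helper: the receive buffer is allocated inside the function and
# returned, so the wrapper looks like a pure function from autograd's point of
# view (no caller-provided recvbuf, no in-place modification of the input).
import torch
from mpi4py import MPI

def pure_allreduce(x: torch.Tensor, comm: MPI.Comm = MPI.COMM_WORLD) -> torch.Tensor:
    out = torch.empty_like(x)
    # detach() only strips the autograd flag so .numpy() is allowed; the
    # underlying data buffer is shared, so no extra copy of x is made here.
    comm.Allreduce(x.detach().contiguous().numpy(), out.numpy(), op=MPI.SUM)
    return out
```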

d1saster · Feb 10 '20 15:02

Introducing "student project" label for potential thesis work.


Reviewed within #1109

ClaudiaComito · Aug 11 '23 04:08

Suggestion for a "prototype"
A good start could be to implement something like diffable_Allreduce_SUM, i.e. an AD-compatible counterpart of ht.comm.Allreduce( ... , op=MPI.SUM). This is maybe the simplest case, since (at least in my understanding) the structure of this function does not cause problems when reversing the direction of the DAG that would need to be caught with dummy constructions etc.

I would suggest trying to define a custom autograd-capable function diffable_Allreduce_SUM with corresponding forward(), backward() and setup_context() (?) as described here:

https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html and in particular: https://pytorch.org/docs/stable/notes/extending.html

Regarding the derivatives (i.e. backward()) for Allreduce (with MPI.SUM), see the great work of @d1saster in https://github.com/helmholtz-analytics/mpi4torch/blob/master/csrc/extension.cpp. In fact, my suggestion for this issue is to proceed similarly in Python to what has been done in C++ for mpi4torch.
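To make the suggestion concrete, here is a minimal, hedged sketch of such a diffable_Allreduce_SUM, assuming mpi4py, MPI.COMM_WORLD and contiguous CPU tensors (the class and helper names are placeholders, not Heat API). Since the Allreduce output on every rank is the sum of all ranks' inputs, the gradient of a scalar loss with respect to each rank's input is the sum of the local gradients over all ranks, i.e. the backward pass is again an Allreduce with MPI.SUM:

```python
import torch
from mpi4py import MPI

class DiffableAllreduceSUM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # Forward: y = sum over all ranks of x, identical on every rank.
        out = torch.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x.detach().contiguous().numpy(), out.numpy(), op=MPI.SUM)
        return out

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # Since y appears in every rank's local loss, dL/dx on each rank is the
        # sum of the local gradients over all ranks: again an Allreduce-SUM.
        grad_input = torch.empty_like(grad_output)
        MPI.COMM_WORLD.Allreduce(
            grad_output.detach().contiguous().numpy(), grad_input.numpy(), op=MPI.SUM
        )
        return grad_input

def diffable_Allreduce_SUM(x: torch.Tensor) -> torch.Tensor:
    return DiffableAllreduceSUM.apply(x)
```

As a quick sanity check, x = torch.ones(3, requires_grad=True); diffable_Allreduce_SUM(x).sum().backward() should leave x.grad equal to the number of ranks on every rank (for the implicit global loss that sums the local losses over all ranks).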

Next Steps after Prototype

  • Allgather and Alltoall should not cause problems with the computational graph either (at least in my opinion) and therefore might be handled in a similar way (a hedged sketch for Allgather follows below)
  • it could also be nice to try out and think about Allreduce with MPI.PROD or MPI.MAX/MIN (?), since this has not been done so far in mpi4torch
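For the Allgather case mentioned above, a hedged sketch could look as follows (same assumptions as before: mpi4py, MPI.COMM_WORLD, contiguous CPU tensors, equal block sizes on all ranks; names are placeholders). Since each rank's contribution appears in the gathered output on every rank, the backward pass sums the corresponding gradient blocks over all ranks and scatters one block back to each rank, i.e. a Reduce_scatter with MPI.SUM:

```python
import torch
from mpi4py import MPI

class DiffableAllgather(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # Forward: every rank receives the stacked contributions of all ranks.
        comm = MPI.COMM_WORLD
        out = torch.empty((comm.Get_size(),) + tuple(x.shape), dtype=x.dtype)
        comm.Allgather(x.detach().contiguous().numpy(), out.numpy())
        return out

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # Backward: sum the per-rank gradient blocks over all ranks and scatter
        # the block belonging to this rank back to it (Reduce_scatter with SUM).
        comm = MPI.COMM_WORLD
        grad_input = torch.empty(grad_output.shape[1:], dtype=grad_output.dtype)
        comm.Reduce_scatter_block(
            grad_output.detach().contiguous().numpy(), grad_input.numpy(), op=MPI.SUM
        )
        return grad_input
```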

Where it gets tricky
In principle this is extensively discussed in the docs of mpi4torch and in the above comment of @d1saster:

  • every operation that destroys the structure of the DAG in its backward pass (unfortunately: most of the operations...)

mrfh92 · Aug 25 '23 14:08