functorch
Use Python aten -> torch decompositions to register "rules" for functorch transforms
Motivation
We've received quite a bit of user feedback that PyTorch's forward-mode AD coverage isn't that good. While one thing we can do to address that is to grind out more forward-mode AD formulas, another way to tackle the problem is to use our pre-existing decompositions: if an operator like `log_softmax_backward` doesn't have a forward-mode AD rule but it does have a decomposition, then we can run the decomposition to achieve forward-mode AD support for `log_softmax_backward`.
This also applies to vmap coverage: if we don't have a batching rule for an operator but we do have a decomposition for it, then using the decomposition is a really fast way to get coverage. The downside of a decomposition is overhead; to go the extra mile we would want to actually write a batching rule.
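To make the forward-mode point concrete, here is a minimal toy sketch in pure Python. It does not use PyTorch or functorch: the `Dual` class and `dual_exp` helper are hypothetical stand-ins for primitives that already have forward-mode rules. The composite op has no rule of its own, but because its decomposition only uses add, subtract, multiply and exp, tangents flow through it for free. The formula used is the usual one for this op, `grad_output - exp(output) * sum(grad_output)`, written for a single 1-D row.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """Toy dual number: a primal value plus its tangent (hypothetical helper)."""
    primal: float
    tangent: float = 0.0

    def __add__(self, other):
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    def __sub__(self, other):
        return Dual(self.primal - other.primal, self.tangent - other.tangent)

    def __mul__(self, other):
        # Product rule for the tangent.
        return Dual(
            self.primal * other.primal,
            self.primal * other.tangent + self.tangent * other.primal,
        )


def dual_exp(x):
    """exp with its forward-mode rule: d/dt exp(x) = exp(x) * dx/dt."""
    e = math.exp(x.primal)
    return Dual(e, e * x.tangent)


def log_softmax_backward_decomp(grad_output, output):
    """Decomposition into primitives that already support forward-mode AD.

    No dedicated forward-mode rule is written for this function.
    """
    total = Dual(0.0)
    for g in grad_output:
        total = total + g
    return [g - dual_exp(o) * total for g, o in zip(grad_output, output)]


# Arbitrary illustrative primals and tangents.
grad_output = [Dual(1.0, 0.5), Dual(-2.0, 1.0), Dual(0.5, -1.0)]
output = [Dual(-1.2, 0.1), Dual(-0.7, 0.0), Dual(-1.6, -0.3)]
result = log_softmax_backward_decomp(grad_output, output)
```

Each element of `result` carries both the value of the op and its directional derivative, even though only the primitives know anything about tangents. The extra intermediate values created along the way are the overhead mentioned above.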
Problems
- We need a way to call the Python decompositions from C++. functorch has decompositions in Python (https://github.com/pytorch/functorch/blob/main/functorch/_src/decompositions.py). However, vmap and forward-mode AD can only be customized from the C++ side.
- Once we can access the decomposition from C++, there needs to be some mechanism to actually get the subsystems to use the decomposition. For vmap this is simply registering the decomposition to the FuncTorchBatched key; for forward-mode AD this is going to be a bit trickier because the PyTorch Autograd key combines both reverse-mode and forward-mode AD.
- Testing: we want some serious testing that the decomposition matches the semantics of the operation being decomposed.
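As a rough illustration of the testing point, the sketch below cross-checks a decomposition against an independent reference on random inputs. It is a pure-Python toy and not the functorch test suite: all function names are hypothetical, the op is restricted to a single 1-D row, and the reference is a central finite-difference gradient of `sum(grad_output * log_softmax(x))`, which is what `log_softmax_backward` is supposed to compute.

```python
import math
import random


def log_softmax(x):
    """Numerically stable log-softmax over a 1-D list."""
    m = max(x)
    lse = m + math.log(sum(math.exp(v - m) for v in x))
    return [v - lse for v in x]


def log_softmax_backward_decomp(grad_output, output):
    """Decomposition under test: grad_output - exp(output) * sum(grad_output)."""
    total = sum(grad_output)
    return [g - math.exp(o) * total for g, o in zip(grad_output, output)]


def reference_vjp(grad_output, x, eps=1e-6):
    """Independent reference: finite-difference gradient of the weighted loss."""
    def loss(xs):
        return sum(g * y for g, y in zip(grad_output, log_softmax(xs)))

    grads = []
    for i in range(len(x)):
        hi, lo = list(x), list(x)
        hi[i] += eps
        lo[i] -= eps
        grads.append((loss(hi) - loss(lo)) / (2 * eps))
    return grads


def check_decomposition(trials=100, size=5, tol=1e-5, seed=0):
    """Return True if the decomposition agrees with the reference on every trial."""
    rng = random.Random(seed)
    for _ in range(trials):
        x = [rng.uniform(-3, 3) for _ in range(size)]
        grad_output = [rng.uniform(-3, 3) for _ in range(size)]
        got = log_softmax_backward_decomp(grad_output, log_softmax(x))
        want = reference_vjp(grad_output, x)
        if any(abs(a - b) > tol for a, b in zip(got, want)):
            return False
    return True


ok = check_decomposition()
```

A real test would presumably compare the decomposition against the actual aten operator across dtypes, shapes and devices; the point here is only the shape of the check, which is many random inputs compared against a reference that does not share code with the decomposition.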
Potential solutions
For calling Python decompositions from C++:
- @eellison @Chillee have been mentioning something to me about getting the decompositions into TorchScript and then making them usable from C++
- Another alternative is to call directly back into Python from C++.
Cool! Yea most of the work has been done here. I have an example of running a decomposition in C++ here, with the decomposition defined here.
I'm going to clean up the API for getting/invoking the decomposition shortly and add a hook to register decompositions, so that we can run the decompositions that are already in functorch now before eventually migrating them over to core.
PR: https://github.com/pytorch/functorch/pull/740/files