functorch

Use Python aten -> torch decompositions to register "rules" for functorch transforms

Open • zou3519 opened this issue 2 years ago • 1 comment

Motivation

We've received quite a bit of user feedback that PyTorch's forward-mode AD coverage isn't that good. While one thing we can do to address that is to grind out more forward-mode AD formulas, another way to tackle the problem is to use our pre-existing decompositions: if an operator like log_softmax_backward doesn't have a forward-mode AD rule but it does have a decomposition, then we can run the decomposition to achieve forward-mode AD support for log_softmax_backward.
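To make the fallback idea concrete, here is a minimal pure-Python sketch. It is not PyTorch code. It assumes scalar dual numbers in place of tensors, and it uses softplus as a hypothetical stand-in for an op like log_softmax_backward that lacks a forward-mode rule. `Dual`, `call`, `JVP_RULES`, and `DECOMPOSITIONS` are illustrative names, not functorch APIs. The op has no JVP rule of its own, so the dispatcher runs its decomposition. Every primitive in that decomposition does have a rule, so tangents propagate correctly.

```python
import math
from dataclasses import dataclass


@dataclass
class Dual:
    """A scalar carrying its primal value and its forward-mode tangent."""
    primal: float
    tangent: float


# Forward-mode (JVP) rules exist only for a handful of primitive ops.
def jvp_exp(x: Dual) -> Dual:
    e = math.exp(x.primal)
    return Dual(e, e * x.tangent)


def jvp_log(x: Dual) -> Dual:
    return Dual(math.log(x.primal), x.tangent / x.primal)


def jvp_add(x: Dual, y: Dual) -> Dual:
    return Dual(x.primal + y.primal, x.tangent + y.tangent)


JVP_RULES = {"exp": jvp_exp, "log": jvp_log, "add": jvp_add}

# Hypothetical composite op with no JVP rule of its own, only a decomposition
# into primitives: softplus(x) = log(1 + exp(x)).
DECOMPOSITIONS = {
    "softplus": lambda x: call("log", call("add", Dual(1.0, 0.0), call("exp", x))),
}


def call(op: str, *args: Dual) -> Dual:
    """Dispatch an op, preferring a JVP rule and falling back to a decomposition."""
    if op in JVP_RULES:
        return JVP_RULES[op](*args)
    if op in DECOMPOSITIONS:
        # No forward-mode rule: run the decomposition, whose pieces do have rules.
        return DECOMPOSITIONS[op](*args)
    raise NotImplementedError(f"no forward-mode AD rule or decomposition for {op!r}")


# Seed tangent 1.0 at x = 0: primal is log(2), tangent is sigmoid(0) = 0.5.
out = call("softplus", Dual(0.0, 1.0))
```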

This applies to vmap coverage as well: if we don't have a batching rule for an operator but do have a decomposition for it, then using the decomposition is a very fast way to get coverage. The downside of a decomposition is overhead; to go the full mile, we would want to write an actual batching rule.
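The same fallback can be sketched for vmap. This pure-Python version assumes that a "batched" value is just a list with one float per example. `BATCHING_RULES`, `DECOMPOSITIONS`, `batched_call`, and the `axpy` op are hypothetical illustrations, not functorch internals. When an op has no batching rule, the dispatcher runs its decomposition on the batched inputs. Each inner op is then dispatched again and lands on a batching rule.

```python
from typing import Callable, Dict, List

# A "batched" value is modeled as a list with one float per example.
Batched = List[float]

# Batching rules: each primitive handles the whole batch in a single call.
BATCHING_RULES: Dict[str, Callable[..., Batched]] = {
    "mul": lambda a, b: [x * y for x, y in zip(a, b)],
    "add": lambda a, b: [x + y for x, y in zip(a, b)],
}

# Hypothetical op "axpy" (a * x + y) has no batching rule, only a decomposition
# written in terms of other ops.
DECOMPOSITIONS: Dict[str, Callable[..., Batched]] = {
    "axpy": lambda a, x, y: batched_call("add", batched_call("mul", a, x), y),
}


def batched_call(op: str, *args: Batched) -> Batched:
    """Dispatch an op on batched inputs, falling back to its decomposition."""
    if op in BATCHING_RULES:
        return BATCHING_RULES[op](*args)
    if op in DECOMPOSITIONS:
        # No batching rule: run the decomposition; every inner op it calls
        # is dispatched here again and hits a batching rule.
        return DECOMPOSITIONS[op](*args)
    raise NotImplementedError(f"no batching rule or decomposition for {op!r}")


result = batched_call("axpy", [1.0, 2.0], [3.0, 4.0], [5.0, 6.0])  # [8.0, 14.0]
```

The decomposition path materializes the intermediate `a * x` batch and issues two primitive calls. A dedicated `axpy` batching rule could do the work in one pass, and that gap is the overhead trade-off described above.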

Problems

  1. We need a way to call the Python decompositions from C++. functorch has decompositions in Python (https://github.com/pytorch/functorch/blob/main/functorch/_src/decompositions.py). However, vmap and forward-mode AD can only be customized from the C++ side.
  2. Once we can access a decomposition from C++, we need some mechanism that makes the subsystems actually use it. For vmap, this is simply a matter of registering the decomposition to the FuncTorchBatched key. For forward-mode AD it is trickier, because the PyTorch Autograd key covers both reverse-mode and forward-mode AD.
  3. Testing: we want thorough tests verifying that each decomposition matches the semantics of the operation it decomposes.
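One way to approach the testing problem is to compare each decomposition against a reference implementation over many sampled inputs, including edge cases. This pure-Python sketch uses the naive softplus decomposition log(1 + exp(x)) as a hypothetical example. The harness name `find_mismatches` and its tolerances are assumptions. A real test suite would presumably also need to cover tensors, dtypes, and devices. Here the harness catches a semantic gap: the decomposition overflows on large inputs where the reference does not.

```python
import math
import random


def softplus_reference(x: float) -> float:
    # Numerically stable formulation, treated as ground truth in this sketch.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def softplus_decomposition(x: float) -> float:
    # Naive decomposition into primitives; math.exp overflows for large x.
    return math.log(1.0 + math.exp(x))


def find_mismatches(reference, decomposition, inputs, rtol=1e-7, atol=1e-9):
    """Return (x, expected, actual_or_error) for every input where the two disagree."""
    mismatches = []
    for x in inputs:
        expected = reference(x)
        try:
            actual = decomposition(x)
        except (OverflowError, ValueError) as err:
            # Raising where the reference succeeds also counts as a mismatch.
            mismatches.append((x, expected, repr(err)))
            continue
        if not math.isclose(actual, expected, rel_tol=rtol, abs_tol=atol):
            mismatches.append((x, expected, actual))
    return mismatches


rng = random.Random(0)
inputs = [rng.uniform(-20.0, 20.0) for _ in range(1000)] + [0.0, -1000.0, 1000.0]
mismatches = find_mismatches(softplus_reference, softplus_decomposition, inputs)
# Only x = 1000.0 is reported: math.exp(1000.0) raises OverflowError.
```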

Potential solutions

For calling Python decompositions from C++:

  • @eellison @Chillee have been mentioning something to me about getting the decompositions into TorchScript and then making them usable from C++
  • Another alternative is to call directly back into Python from C++.

zou3519 • Apr 21 '22 18:04

Cool! Yeah, most of the work has been done here. I have an example of running a decomposition in C++ here, with the decomposition defined here.

I'm going to clean up the API for getting and invoking decompositions shortly. I'll also add a hook for registering them, so that we can run the decompositions that already live in functorch now, before eventually migrating them to core.

PR: https://github.com/pytorch/functorch/pull/740/files
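A registration hook like the one proposed in this comment might look roughly like the following pure-Python sketch. `register_decomp`, `get_decomp`, `add_decomp_hook`, and the `hypothetical::square` op name are all invented for illustration and are not the functorch or PyTorch API. In the actual design, the subscribers would be C++-side subsystems such as the batching or forward-AD fallbacks rather than Python callbacks.

```python
from typing import Callable, Dict, List

# Toy registry (hypothetical); not the functorch or PyTorch implementation.
_DECOMPS: Dict[str, Callable] = {}
_HOOKS: List[Callable[[str, Callable], None]] = []


def add_decomp_hook(hook: Callable[[str, Callable], None]) -> None:
    """Let a subsystem (e.g. a vmap or forward-AD fallback) observe registrations."""
    _HOOKS.append(hook)
    # Replay earlier registrations so a late subscriber still sees everything.
    for name, fn in _DECOMPS.items():
        hook(name, fn)


def register_decomp(op_name: str) -> Callable[[Callable], Callable]:
    """Decorator that records fn as the decomposition for op_name and notifies hooks."""
    def decorator(fn: Callable) -> Callable:
        if op_name in _DECOMPS:
            raise ValueError(f"decomposition for {op_name!r} is already registered")
        _DECOMPS[op_name] = fn
        for hook in _HOOKS:
            hook(op_name, fn)
        return fn
    return decorator


def get_decomp(op_name: str):
    """Return the registered decomposition, or None if there is none."""
    return _DECOMPS.get(op_name)


# Usage: a subsystem subscribes, then a decomposition is registered.
seen: List[str] = []
add_decomp_hook(lambda name, fn: seen.append(name))


@register_decomp("hypothetical::square")
def square_decomp(x):
    return x * x
```

Replaying existing entries in `add_decomp_hook` means the order in which subsystems initialize does not matter. That property would likely be useful if the Python decompositions are loaded before the C++ side subscribes.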

eellison • Apr 22 '22 17:04