functorch
Figure out how to get test coverage for more compositions of transforms
Motivation
Currently, we only test the following compositions:
- vmap
- jvp
- vjp
- vmap x jvp
- vmap x vjp
- vjp x vjp
- vjp x vmap
This has caught most of our bugs, but users still report code that fails because it uses a composition outside this list. For example:
- vmap x vmap can still error out even when vmap alone works
- vmap x vjp x vjp can error out if some backward operator (e.g. convolution_backward) has a backward formula that is not composite compliant. The same applies to vmap x jvp x vjp.
The Ask
Figure out how to get better test coverage for more compositions of transforms.
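To make the size of the gap concrete, here is a small sketch that enumerates every ordered stack of vmap, jvp, and vjp up to a given depth and subtracts the compositions listed under Motivation. It assumes "A x B" means A applied outermost, i.e. A(B(f)); the names `TESTED`, `compositions`, and `untested` are hypothetical and not part of functorch's test suite.

```python
from itertools import product

# The three transforms whose compositions we care about.
TRANSFORMS = ("vmap", "jvp", "vjp")

# Compositions currently tested (from the Motivation list).
# Assumption: tuples are ordered outermost transform first, so
# ("vmap", "jvp") means vmap(jvp(f)).
TESTED = {
    ("vmap",), ("jvp",), ("vjp",),
    ("vmap", "jvp"), ("vmap", "vjp"),
    ("vjp", "vjp"), ("vjp", "vmap"),
}


def compositions(max_depth):
    """Yield every ordered stack of transforms of length 1..max_depth."""
    for depth in range(1, max_depth + 1):
        yield from product(TRANSFORMS, repeat=depth)


def untested(max_depth):
    """Return the stacks up to max_depth that are not in TESTED."""
    return [c for c in compositions(max_depth) if c not in TESTED]


if __name__ == "__main__":
    missing = untested(3)
    print(f"{len(missing)} untested compositions up to depth 3")
    for stack in missing:
        print(" x ".join(stack))
```

Up to depth 3 there are 39 stacks, 32 of them untested, including vmap x vmap, vmap x vjp x vjp, and vmap x jvp x vjp from the examples above. Since the count grows as 3^d with depth, an exhaustive test matrix gets expensive quickly, which is one argument for prioritizing specific stacks or testing the underlying aten operators directly.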
Possibly related: OpInfos
This is also related to better OpInfo testing. OpInfos do not cover all aten operators. One way to get much better coverage out of our existing tests is to add OpInfos for torch.ops.aten operations. For example, instead of checking the batching rule of torch.ops.aten.convolution_backward indirectly through a vmap x vjp test, it would be sufficient to run a vmap test on torch.ops.aten.convolution_backward directly.