Decompositions request
We need decompositions for the forward and backward passes of:
- [x] nn.functional.nll_loss (we already have a decomposition for nll_loss backward in vmap's batching rule; please check that it supports all cases and copy-paste it into Python) @samdow
- [x] nn.functional.l1_loss @samdow (@Chillee added the l1_loss forward, so we just need the backward; see the sketch after this list).
- [x] nn.functional.batch_norm (check with @Chillee, there's a PR open for this)
- ~~[ ] nn.functional.instance_norm~~ Uses batch_norm, so it doesn't have a derivatives.yaml entry
- [ ] fix the nn.functional.embedding decomp: the docs and the decomp have different signatures. We don't really care about the sparse or renormalization features, though.
- [x] binary_cross_entropy (forward)
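For reference, here is a minimal sketch of what the l1_loss backward decomposition could look like. The standalone function name and signature are assumptions (the real version should be registered via register_decomposition in decompositions.py), and the reduction codes follow the aten convention (0 = none, 1 = mean, 2 = sum):

```python
import torch
from torch import Tensor

# Sketch only: signature assumed to mirror aten.l1_loss_backward.
# aten reduction codes: 0 = none, 1 = mean, 2 = sum.
def l1_loss_backward(grad_output: Tensor, input: Tensor, target: Tensor, reduction: int) -> Tensor:
    sign = torch.sign(input - target)   # d|input - target| / d(input), with sign(0) = 0
    if reduction == 1:                  # mean: each element contributes 1/numel to the loss
        return grad_output * sign / input.numel()
    return grad_output * sign           # sum / none: pass the gradient straight through
```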
Motivation
Forward-mode AD coverage. We're going to use the decompositions to understand what the backward passes of these operations do. Then, we'll either:
1. write a forward-mode AD formula for the backward pass, or
2. use the decomposition instead of writing the forward-mode AD formula to get forward-over-reverse coverage.
Neither (1) nor (2) is too difficult once we have the decomposition written.
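Concretely, "forward-over-reverse" means running forward-mode AD through a reverse-mode gradient. A hedged sketch using the functorch API (l1_loss and the specific tensors here are just placeholders): if l1_loss_backward has no forward AD formula, a decomposition lets jvp trace through its constituent ops instead.

```python
import torch
import torch.nn.functional as F
from functorch import grad, jvp

# Scalar-valued loss of x; the zero target is arbitrary, for illustration only.
f = lambda x: F.l1_loss(x, torch.zeros_like(x))

x = torch.randn(3)
v = torch.randn(3)

# Hessian-vector product via forward-over-reverse: the jvp of the
# reverse-mode gradient. This pushes forward-mode AD through the
# backward pass of l1_loss.
grad_out, hvp = jvp(grad(f), (x,), (v,))
```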
Instructions
Add the decomposition to https://github.com/pytorch/functorch/blob/main/functorch/_src/decompositions.py and make sure the decomposition tests in test_ops.py pass.
Also, if you're adding a decomposition for the forward pass of an operator (like l1_loss), please make sure the decomposition has error checks for valid inputs. OpInfos do not test invalid inputs, so we're flying blind w.r.t. testing here.
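For illustration, a hedged sketch of what such input validation could look like in an l1_loss forward decomposition. The exact checks and error messages here are assumptions; in practice they should mirror whatever the eager kernel raises:

```python
import torch
from torch import Tensor

# Sketch only: validation for an assumed l1_loss forward decomposition.
# aten reduction codes: 0 = none, 1 = mean, 2 = sum.
def l1_loss(input: Tensor, target: Tensor, reduction: int = 1) -> Tensor:
    if reduction not in (0, 1, 2):
        # hypothetical message; match the eager kernel's wording in practice
        raise RuntimeError(f"l1_loss: {reduction} is not a valid value for reduction")
    if input.shape != target.shape:
        # hypothetical check; eager may broadcast with a warning instead
        raise RuntimeError(
            f"l1_loss: input and target must have the same shape, "
            f"got {input.shape} and {target.shape}"
        )
    loss = (input - target).abs()
    if reduction == 1:   # mean
        return loss.mean()
    if reduction == 2:   # sum
        return loss.sum()
    return loss          # none
```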
NB
There is an in-flight PR moving the functorch decompositions over to PyTorch: https://github.com/pytorch/pytorch/pull/76311. It might not include the decompositions we're adding now. functorch has better testing (for now, at least), but after https://github.com/pytorch/pytorch/pull/76311 lands, we should move whatever decompositions we write here over there.
I added a decomposition for aten.native_batch_norm (forward); a follow-up PR for the backward is in progress.
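For context, a minimal sketch of the training-mode core of such a forward decomposition. This ignores running-stat updates and the eval path, and the helper name and exact return convention are assumptions:

```python
import torch
from torch import Tensor
from typing import Optional, Tuple

# Sketch only: training-mode core of a native_batch_norm forward
# decomposition, returning (output, save_mean, save_rstd).
def native_batch_norm_train(
    input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], eps: float
) -> Tuple[Tensor, Tensor, Tensor]:
    dims = [0] + list(range(2, input.dim()))   # reduce over all dims except channel
    mean = input.mean(dims)
    var = input.var(dims, unbiased=False)      # biased variance, as in training mode
    rstd = torch.rsqrt(var + eps)
    shape = [1, -1] + [1] * (input.dim() - 2)  # broadcast per-channel stats
    out = (input - mean.reshape(shape)) * rstd.reshape(shape)
    if weight is not None:
        out = out * weight.reshape(shape)
    if bias is not None:
        out = out + bias.reshape(shape)
    return out, mean, rstd
```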