Decompositions request

Open zou3519 opened this issue 3 years ago • 1 comment

We need decompositions for the forward and backward passes of:

  • [x] nn.functional.nll_loss (we already have a decomposition for nll_loss backward in vmap's batching rule; please check that it supports all cases and copy-paste it into Python) @samdow
  • [x] nn.functional.l1_loss @samdow (@Chillee added l1_loss forward, so we just need backward).
  • [x] nn.functional.batch_norm (check with @Chillee, there's a PR open for this)
  • ~[ ] nn.functional.instance_norm~ Uses batch norm, so it doesn't have a derivatives.yaml entry
  • [ ] fix nn.functional.embedding decomp: the docs and the decomp have different signatures. We don't really care about the sparse or renormalization features, though
  • [x] binary_cross_entropy (forward)
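
As a rough illustration of what a backward decomposition in this list might look like, here is a minimal sketch for l1_loss backward. The function name `l1_loss_backward_decomp` and its signature are assumptions made for this example, not the signature registered in functorch, and the sketch assumes `input` and `target` have the same shape (no broadcasting).

```python
import torch
import torch.nn.functional as F


def l1_loss_backward_decomp(grad_output, input, target, reduction="mean"):
    # Hypothetical helper: gradient of |input - target| with respect to input,
    # written only in terms of simpler ops (sub, sign, div, mul).
    grad_input = torch.sign(input - target)
    if reduction == "mean":
        # The mean reduction scales every element by 1 / numel.
        grad_input = grad_input / input.numel()
    # For "none", grad_output has input's shape; otherwise it is a scalar.
    return grad_output * grad_input


torch.manual_seed(0)
x = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
y = torch.randn(3, 4, dtype=torch.float64)

# Compare against what autograd computes for the eager op.
results = {}
for reduction in ("none", "mean", "sum"):
    loss = F.l1_loss(x, y, reduction=reduction)
    grad_output = torch.ones_like(loss)
    (expected,) = torch.autograd.grad(loss, x, grad_output)
    actual = l1_loss_backward_decomp(grad_output, x.detach(), y, reduction)
    results[reduction] = torch.allclose(actual, expected)
```

The comparison against `torch.autograd.grad` is only a sanity check for this sketch; the real coverage comes from the decomposition tests in test_ops.py.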

Motivation

Forward-mode AD coverage. We're going to use the decompositions to understand what the backward passes of these operators do. Then, we'll either:

  1. write a forward-mode AD formula for the backward pass, or
  2. use the decomposition instead of writing the forward-mode AD formula to get forward-over-reverse coverage.

Both (1) and (2) are not too difficult to do once we've got the decomposition written.
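
To make option (2) concrete, here is a minimal sketch, assuming a simplified binary_cross_entropy backward (mean reduction, no weight, and none of the clamping the eager kernel applies). The name `bce_backward_decomp` is hypothetical. Because the decomposition uses only primitive ops that already have forward-mode formulas, forward-mode AD can run straight through it; the result is checked against reverse-over-reverse on the eager op.

```python
import torch
import torch.autograd.forward_ad as fwAD
import torch.nn.functional as F


def bce_backward_decomp(grad_output, input, target):
    # Hypothetical, simplified decomposition (mean reduction, no weight):
    # d/dx of -(t*log(x) + (1-t)*log(1-x)) is (x - t) / (x * (1 - x)).
    return grad_output * (input - target) / (input * (1 - input)) / input.numel()


torch.manual_seed(0)
x = torch.rand(5, dtype=torch.float64) * 0.8 + 0.1  # keep x away from 0 and 1
t = torch.rand(5, dtype=torch.float64)
v = torch.randn(5, dtype=torch.float64)  # tangent direction
grad_out = torch.ones((), dtype=torch.float64)

# Forward-over-reverse: push a tangent through the decomposed backward pass.
with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, v)
    grad_input = bce_backward_decomp(grad_out, dual_x, t)
    jvp_via_decomp = fwAD.unpack_dual(grad_input).tangent

# Reference: reverse-over-reverse (double backward) on the eager op.
x_ref = x.clone().requires_grad_(True)
loss = F.binary_cross_entropy(x_ref, t)
(g,) = torch.autograd.grad(loss, x_ref, create_graph=True)
(hvp,) = torch.autograd.grad(g, x_ref, v)

matches = torch.allclose(jvp_via_decomp, hvp)
```

Both paths compute the same Hessian-vector product here; the point of the sketch is only that no hand-written forward-mode formula for the backward op was needed.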

Instructions

Add the decomposition to https://github.com/pytorch/functorch/blob/main/functorch/_src/decompositions.py. Make sure the decomposition tests in test_ops.py pass.

Also, if you're adding a decomposition for the forward pass of an operator (like l1_loss), please make sure the decomposition has error checks for valid inputs. OpInfos do not test invalid inputs, so we're flying blind w.r.t. testing here.
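
As a sketch of what such error checks could look like, here is a hypothetical forward decomposition for l1_loss. The name `l1_loss_decomp`, the error messages, and the decision to reject mismatched shapes outright are assumptions of this example; they are not taken from the decomposition in functorch or from the eager op's exact behavior.

```python
import torch
import torch.nn.functional as F


def l1_loss_decomp(input, target, reduction="mean"):
    # Hypothetical forward decomposition with explicit input validation.
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(f"{reduction!r} is not a valid value for reduction")
    # Assumption of this sketch: require matching shapes instead of broadcasting.
    if input.shape != target.shape:
        raise ValueError(
            f"input shape {tuple(input.shape)} must match "
            f"target shape {tuple(target.shape)}"
        )
    loss = (input - target).abs()
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss


torch.manual_seed(0)
x = torch.randn(2, 3)
y = torch.randn(2, 3)

# Valid inputs should agree with the eager op for every reduction mode.
matches = all(
    torch.allclose(l1_loss_decomp(x, y, r), F.l1_loss(x, y, reduction=r))
    for r in ("none", "mean", "sum")
)

# Invalid inputs should be rejected rather than silently accepted.
try:
    l1_loss_decomp(x, y, reduction="average")
    rejected = False
except ValueError:
    rejected = True
```

Since OpInfos only exercise valid inputs, a check like the `rejected` one above would have to be written by hand.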

NB

There is an in-flight PR that is moving the functorch decompositions over to PyTorch: https://github.com/pytorch/pytorch/pull/76311. That might not include these decompositions we're adding now. Functorch has better testing (for now at least), but after https://github.com/pytorch/pytorch/pull/76311 goes in we should move whatever decompositions we wrote here over there.

zou3519 • Apr 26 '22 16:04

I added a decomposition for aten.native_batch_norm (forward), and I'm following up with the PR for backward.

Chillee • Apr 29 '22 02:04