
Batch Norm Consolidation

Open andrewor14 opened this issue 1 year ago • 5 comments

Stack from ghstack (oldest at bottom):

  • #119496
  • -> #116092

Summary:

This commit simplifies the existing decomposition hierarchy of batch norm ops by adding a single, backend-agnostic op: batch_norm_with_update. The existing hierarchy looks like:

aten.batch_norm ->
aten._batch_norm_impl_index ->
[
  aten.native_batch_norm ->
  aten._native_batch_norm_legit (export only) ->
  _batch_norm_legit_cpu/cuda (kernels, export only) ->
  _batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]

Aside from complexity, an important problem with the above decomposition hierarchy is CUDA numerics in export flows. We observed significantly worse convergence when training a mobilenetv2-like model using the _batch_norm_cuda kernel instead of the cudnn_batch_norm kernel. This means users who export their models on CPU and then move them to CUDA later may silently see worse accuracies even when cuDNN is installed, because they are using the worse kernel. This issue is summarized in https://github.com/pytorch/pytorch/issues/111384.

Instead, the new hierarchy, which consolidates the existing batch norm ops, will look like:

aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]

The new op batch_norm_with_update hides backend implementation details and automatically picks the right kernel based on what is installed. This commit also adds the following variants to this op:

batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
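
For illustration, a rough sketch of how the consolidated op could be exercised once the decomps route through it. The op name comes from this PR, but the exact signature and outputs shown here are assumptions modeled on the existing native_batch_norm schema, not the committed API:

```python
# Rough sketch only: assumes batch_norm_with_update takes
# (input, weight, bias, running_mean, running_var, momentum, eps) like
# native_batch_norm, and returns (output, save_mean, save_rstd, reserve).
import torch

x = torch.randn(2, 3, 4, 4)
weight, bias = torch.ones(3), torch.zeros(3)
running_mean, running_var = torch.zeros(3), torch.ones(3)

# Reference result from the existing composite op (training mode).
ref = torch.nn.functional.batch_norm(
    x, running_mean.clone(), running_var.clone(), weight, bias,
    training=True, momentum=0.1, eps=1e-5,
)

# The consolidated op updates running_mean/running_var in place and should
# match the reference regardless of which backend kernel it dispatches to.
out, save_mean, save_rstd, reserve = torch.ops.aten.batch_norm_with_update(
    x, weight, bias, running_mean, running_var, 0.1, 1e-5,
)
torch.testing.assert_close(out, ref)
```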

Note that this commit only adds this op and its variants, but does not actually change the decomps to produce these ops in the graph. This will be done after the 2 week FC window, and the ops in the old hierarchy are planned to be removed after the 6 month BC window.

Test Plan: OpInfo tests for batch_norm_with_update.

Reviewers: albanD, bdhirsh

Subscribers: albanD, bdhirsh, supriyar

Tasks: https://github.com/pytorch/pytorch/issues/111384

andrewor14 avatar Dec 19 '23 09:12 andrewor14

:link: Helpful Links

:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/116092

Note: Links to docs will display an error until the docs builds have been completed.

:white_check_mark: You can merge normally! (1 Unrelated Failure)

As of commit e3546b5b7e62c62879704542b0c9f6d62e5a90a8 with merge base f2f8eeea944f5cc6dd6f907a0c78067f73e0ad9c:

FLAKY - The following job failed but was likely due to flakiness present on trunk.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] avatar Dec 19 '23 09:12 pytorch-bot[bot]

might be related to consolidation in the frontend sense:

  • https://github.com/pytorch/pytorch/issues/41243
  • https://github.com/pytorch/pytorch/issues/66073

vadimkantorov avatar Dec 19 '23 10:12 vadimkantorov

Overall looks good! I mentioned that a dummy PR to test that the new decomposition logic for batchnorm works might be a good idea, but two other tests that I think would be good in this PR are:

(1) A test that uses torch.export with batchnorm (with and without training), with expectations on what the exported graph looks like (a rough sketch follows below this list)

(2) A test that uses torch.compile with batchnorm, showing that (with the decomposition) we DCE the unused cudnn tensor
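
To make (1) concrete, here is a rough sketch of what such a test could look like. It assumes torch.export can trace a training-mode BatchNorm2d module, and which op to assert on is left open until the decomp change actually lands:

```python
# Sketch of test (1): export a batchnorm module and inspect which batch norm
# op shows up in the graph. Which op to expect (old _native_batch_norm_legit*
# vs. the new batch_norm_with_update) depends on whether the decomp has been
# flipped, so the inspection below is illustrative, not the final assertion.
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(x)

ep = torch.export.export(M().train(), (torch.randn(2, 3, 4, 4),))
targets = {
    str(node.target)
    for node in ep.graph.nodes
    if node.op == "call_function"
}
print(targets)  # check which batch norm op appears in the exported graph
```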

bdhirsh avatar Feb 06 '24 23:02 bdhirsh

Although I guess (1) is hard to test in just this PR, since we'd need to change the at::batch_norm decomp to actually go to the new op first.

If you want to save it for after the FC window I think that's probably OK, although another option is to feature-flag it and just flip the feature flag in your test.
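
A minimal sketch of the feature-flag idea, where both the flag name and its location are invented for illustration (there is no such flag in PyTorch today):

```python
# Hypothetical feature flag guarding the new decomposition path; the name
# `use_batch_norm_with_update` and its location in torch._decomp are made up
# purely for this sketch.
from unittest import mock
import torch._decomp as decomp

with mock.patch.object(decomp, "use_batch_norm_with_update", True, create=True):
    # run the same export / compile checks as above against the new decomp path
    ...
```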

bdhirsh avatar Feb 06 '24 23:02 bdhirsh

Would this by any chance fix the difference in saved_mean and saved_rstd of batch_norm when training is False in cpu vs cuda?

https://github.com/pytorch/pytorch/blob/7ad4ab4765f52cc917fdc1b587f5f6e6d3175cad/torch/_decomp/decompositions.py#L1657-L1663

It confuses downstream device-agnostic operator converters about which shape they should respect. Fortunately, these are mostly unused outputs under this specific code path, but it would still be nicer if this discrepancy could be resolved.
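
For context, a small repro sketch of the discrepancy being described (eval mode, training=False); it just prints the shapes of the extra outputs on each device rather than asserting anything:

```python
# Compare the shapes of save_mean / save_rstd returned by native_batch_norm
# in eval mode (training=False) on CPU vs. CUDA; the discrepancy discussed
# here is in these (mostly unused) extra outputs.
import torch

x = torch.randn(2, 3, 4, 4)
weight, bias = torch.ones(3), torch.zeros(3)
running_mean, running_var = torch.zeros(3), torch.ones(3)

out = torch.ops.aten.native_batch_norm(
    x, weight, bias, running_mean, running_var, False, 0.1, 1e-5
)
print("cpu:", [t.shape for t in out[1:]])

if torch.cuda.is_available():
    out_cuda = torch.ops.aten.native_batch_norm(
        x.cuda(), weight.cuda(), bias.cuda(),
        running_mean.cuda(), running_var.cuda(), False, 0.1, 1e-5
    )
    print("cuda:", [t.shape for t in out_cuda[1:]])
```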

BowenBao avatar Feb 16 '24 22:02 BowenBao

Would this by any chance fix the difference in saved_mean and saved_rstd of batch_norm when training is False in cpu vs cuda?

https://github.com/pytorch/pytorch/blob/7ad4ab4765f52cc917fdc1b587f5f6e6d3175cad/torch/_decomp/decompositions.py#L1657-L1663

It confuses downstream device-agnostic operator converters about which shape they should respect. Fortunately, these are mostly unused outputs under this specific code path, but it would still be nicer if this discrepancy could be resolved.

Hi @BowenBao, I don't think this PR fixes this unfortunately. Here we're mostly just trying to preserve the existing behavior with fewer ops and we're reusing the same decompositions as before. Is there an existing issue for this? If not, please feel free to file one and we can take a look separately.

andrewor14 avatar Feb 20 '24 16:02 andrewor14

@andrewor14 here is the issue for reference https://github.com/pytorch/pytorch/issues/100985

BowenBao avatar Feb 20 '24 17:02 BowenBao

@BowenBao I think fixing this would be a problem independent from the one here. The change here faithfully preserves the current API and only shuffles around the implementation. Fixing that issue would require a BC-breaking change.

albanD avatar Feb 20 '24 17:02 albanD

@albanD it would be better for the decomps to not use cuDNN; Inductor can pattern-match to cuDNN if it needs to.

jansel avatar Feb 21 '24 02:02 jansel

@pytorchbot merge

andrewor14 avatar Mar 05 '24 02:03 andrewor14

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here

pytorchmergebot avatar Mar 05 '24 02:03 pytorchmergebot

The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.

pytorchmergebot avatar Mar 05 '24 08:03 pytorchmergebot

@pytorchbot rebase

andrewor14 avatar Mar 05 '24 15:03 andrewor14

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

pytorchmergebot avatar Mar 05 '24 15:03 pytorchmergebot

Successfully rebased gh/andrewor14/48/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/116092)

pytorchmergebot avatar Mar 05 '24 15:03 pytorchmergebot

@pytorchbot rebase

andrewor14 avatar Mar 05 '24 19:03 andrewor14

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

pytorchmergebot avatar Mar 05 '24 19:03 pytorchmergebot

Tried to rebase and push PR #116092, but it was already up to date. Try rebasing against main by issuing: @pytorchbot rebase -b main

pytorchmergebot avatar Mar 05 '24 19:03 pytorchmergebot

@pytorchbot merge

andrewor14 avatar Mar 06 '24 04:03 andrewor14

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here

pytorchmergebot avatar Mar 06 '24 04:03 pytorchmergebot

@pytorchbot revert

jeffdaily avatar Mar 06 '24 17:03 jeffdaily

❌ 🤖 pytorchbot command failed:

@pytorchbot revert: error: the following arguments are required: -m/--message, -c/--classification

usage: @pytorchbot revert -m MESSAGE -c
                          {nosignal,ignoredsignal,landrace,weird,ghfirst}

Try @pytorchbot --help for more info.

pytorch-bot[bot] avatar Mar 06 '24 17:03 pytorch-bot[bot]

@pytorchbot revert -m "broke ROCm, PR signal was clean but trunk was not, the merge should have been blocked but wasn't" -c weird

jeffdaily avatar Mar 06 '24 17:03 jeffdaily

@pytorchbot successfully started a revert job. Check the current status here. Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot avatar Mar 06 '24 17:03 pytorchmergebot

@andrewor14 your PR has been successfully reverted.

pytorchmergebot avatar Mar 06 '24 17:03 pytorchmergebot

@pytorchbot revert -m "broke ROCm, PR signal was clean but trunk was not, the merge should have been blocked but wasn't" -c weird

Hi @jeffdaily, could you point me to the ROCm error so I can take a look?

andrewor14 avatar Mar 06 '24 19:03 andrewor14

@andrewor14 this link should take you to the trunk commit summary for when the PR was landed. You'll see 3 ROCm failures.

https://hud.pytorch.org/pytorch/pytorch/commit/5680f565d5b7d4aa412a3988d3d91ca4c5679303

jeffdaily avatar Mar 06 '24 20:03 jeffdaily

@pytorchbot rebase

andrewor14 avatar Mar 08 '24 00:03 andrewor14

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

pytorchmergebot avatar Mar 08 '24 00:03 pytorchmergebot

Successfully rebased gh/andrewor14/48/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/116092)

pytorchmergebot avatar Mar 08 '24 00:03 pytorchmergebot