Batch Norm Consolidation
Stack from ghstack (oldest at bottom):
- #119496
- -> #116092
Summary:
This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend-agnostic op:
batch_norm_with_update. The existing hierarchy looks like:
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
  aten.native_batch_norm ->
  aten._native_batch_norm_legit (export only) ->
  _batch_norm_legit_cpu/cuda (kernels, export only) ->
  _batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
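For context, one way to see which of the ops above actually run in eager mode is to profile a call to F.batch_norm. This is only an illustrative sketch (not part of this PR); which names show up depends on the device and build:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)
weight, bias = torch.ones(3), torch.zeros(3)
running_mean, running_var = torch.zeros(3), torch.ones(3)

# Profile one eager batch_norm call and list the batch-norm aten ops that ran.
# On a CPU tensor this typically shows aten::batch_norm and
# aten::native_batch_norm; with CUDA tensors on a cuDNN build the same call
# goes through aten::cudnn_batch_norm instead.
with torch.profiler.profile() as prof:
    F.batch_norm(x, running_mean, running_var, weight, bias,
                 training=True, momentum=0.1, eps=1e-5)

print([evt.name for evt in prof.events() if "batch_norm" in evt.name])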
Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
_batch_norm_cuda kernel instead of the cudnn_batch_norm
kernel. This means users who export their models on CPU
first and then move them to cuda later may silently
see worse accuracy even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.
Instead, the new hierarchy, which consolidates the existing batch norm ops, will look like:
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
The new op batch_norm_with_update hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
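Once this commit is in a build, the registered schemas of the new ops can be inspected directly; a minimal sketch (op names taken from the list above, assuming they are registered under torch.ops.aten; which overloads exist depends on the build):

import torch

# Print the registered schemas of the newly added batch norm ops. This assumes
# a build that already contains this commit; on older builds getattr will fail.
for name in ["batch_norm_with_update", "batch_norm_no_update", "batch_norm_backward"]:
    packet = getattr(torch.ops.aten, name)
    for overload_name in packet.overloads():
        print(getattr(packet, overload_name)._schema)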
Note that this commit only adds this op and its variants; it does not actually change the decomps to produce these ops in the graph. That will be done after the 2 week FC window, and the ops in the old stack are planned to be removed after the 6 month BC window.
Test Plan: OpInfo tests for batch_norm_with_update.
Reviewers: albanD, bdhirsh
Subscribers: albanD, bdhirsh, supriyar
Tasks: https://github.com/pytorch/pytorch/issues/111384
:white_check_mark: You can merge normally! (1 Unrelated Failure)
As of commit e3546b5b7e62c62879704542b0c9f6d62e5a90a8 with merge base f2f8eeea944f5cc6dd6f907a0c78067f73e0ad9c:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
- inductor / rocm6.0-py3.8-inductor / test (inductor, 1, 1, linux.rocm.gpu.2) (gh)
test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_unwaited
This might be related to consolidation in the frontend sense:
- https://github.com/pytorch/pytorch/issues/41243
- https://github.com/pytorch/pytorch/issues/66073
Overall looks good! I mentioned that a dummy PR to test that the new decomposition logic for batchnorm works might be a good idea, but two other tests that I think would be good in this PR are:
(1) A test that uses torch.export with batchnorm (with and without training), asserting on what the graph looks like (sketched below)
(2) A test that uses torch.compile with batchnorm, showing that (with the decomposition) we DCE the unused cudnn tensor
Although I guess (1) is hard to test just in this PR, since we'd need to change the at::batch_norm decomp to actually go to the new op first.
If you want to save it for after the FC window I think that's probably ok - although another option is to feature flag it, and just flip the feature flag in your test.
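A rough sketch of what (1) could look like, assuming torch.export and simple string matching on op names; which batch norm op actually appears depends on whether the decomp has been switched over yet (or on the feature flag mentioned above):

import torch
from torch import nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(x)

x = torch.randn(1, 3, 8, 8)
ep = torch.export.export(M().eval(), (x,))

# Collect the batch-norm-related aten ops in the exported graph. With the old
# decomps this is a _native_batch_norm_legit* op; after the switch we would
# expect batch_norm_with_update / batch_norm_no_update here instead.
bn_nodes = [
    str(node.target)
    for node in ep.graph_module.graph.nodes
    if node.op == "call_function" and "batch_norm" in str(node.target)
]
print(bn_nodes)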
Would this by any chance fix the difference in saved_mean and saved_rstd of batch_norm when training is False in cpu vs cuda?
https://github.com/pytorch/pytorch/blob/7ad4ab4765f52cc917fdc1b587f5f6e6d3175cad/torch/_decomp/decompositions.py#L1657-L1663
It confuses downstream device-agnostic operator converters about which shape they should respect. Fortunately, these are mostly unused outputs under this specific code path, but it would still be nicer if this discrepancy could be resolved.
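For reference, a minimal sketch that surfaces the kind of discrepancy being described, calling the existing op directly with training=False on each device (illustrative only; the shapes printed depend on the backend):

import torch

def saved_stat_shapes(device):
    C = 4
    x = torch.randn(2, C, 8, 8, device=device)
    weight = torch.ones(C, device=device)
    bias = torch.zeros(C, device=device)
    running_mean = torch.zeros(C, device=device)
    running_var = torch.ones(C, device=device)
    # native_batch_norm returns (output, save_mean, save_rstd).
    _, save_mean, save_rstd = torch.ops.aten.native_batch_norm(
        x, weight, bias, running_mean, running_var,
        False,   # training
        0.1,     # momentum
        1e-5,    # eps
    )
    return tuple(save_mean.shape), tuple(save_rstd.shape)

print("cpu :", saved_stat_shapes("cpu"))
if torch.cuda.is_available():
    print("cuda:", saved_stat_shapes("cuda"))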
Hi @BowenBao, I don't think this PR fixes this, unfortunately. Here we're mostly just trying to preserve the existing behavior with fewer ops, and we're reusing the same decompositions as before. Is there an existing issue for this? If not, please feel free to file one and we can take a look separately.
@andrewor14 here is the issue for reference https://github.com/pytorch/pytorch/issues/100985
@BowenBao I think fixing this would be an independent problem from the one here. The change here faithfully preserves the current API and only shuffles around the implementation. To fix that issue, we would need to do a BC-breaking change.
@albanD it would be better for decomps not to use cudnn; inductor can pattern match to cudnn if it needs to.
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours).
Learn more about merging in the wiki.
Questions? Feedback? Please reach out to the PyTorch DevX Team
Advanced Debugging: check the merge workflow status here
The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here
Successfully rebased gh/andrewor14/48/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/116092)
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here
Tried to rebase and push PR #116092, but it was already up to date. Try rebasing against main by issuing:
@pytorchbot rebase -b main
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours).
Learn more about merging in the wiki.
Questions? Feedback? Please reach out to the PyTorch DevX Team
Advanced Debugging: check the merge workflow status here
@pytorchbot revert
❌ 🤖 pytorchbot command failed:
@pytorchbot revert: error: the following arguments are required: -m/--message, -c/--classification
usage: @pytorchbot revert -m MESSAGE -c
{nosignal,ignoredsignal,landrace,weird,ghfirst}
Try @pytorchbot --help for more info.
@pytorchbot revert -m "broke ROCm, PR signal was clean but trunk was not, the merge should have been blocked but wasn't" -c weird
@pytorchbot successfully started a revert job. Check the current status here. Questions? Feedback? Please reach out to the PyTorch DevX Team
@andrewor14 your PR has been successfully reverted.
@pytorchbot revert -m "broke ROCm, PR signal was clean but trunk was not, the merge should have been blocked but wasn't" -c weird
Hi @jeffdaily, could you point me to the ROCm error so I can take a look?
@andrewor14 this link should take you to the trunk commit summary for when the PR was landed. You'll see 3 rocm failures.
https://hud.pytorch.org/pytorch/pytorch/commit/5680f565d5b7d4aa412a3988d3d91ca4c5679303
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here
Successfully rebased gh/andrewor14/48/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/116092)