[ROCm] CK Flash Attention Backend
Replaces https://github.com/ROCm/pytorch/pull/1592
Updated implementation of the CK flash attention backend. The previous PR can be closed.
This PR will generate the CK kernels necessary for flash attention. Currently they will be generated in pytorch/aten/src/ATen/native/transformers/hip/flash_attn/. The CK path can be forced by simply calling torch.backends.cuda.preferred_rocm_fa_library("ck"). Similarly, you can force the incumbent aotriton implementation by passing in "aotriton" or "default". As you'd expect, not setting this option results in aotriton being used as the backend.

In the case of CK, if PyTorch deems flash attention usable, it will use the CK path in all the same places aotriton would have been used. This PR makes no changes to the heuristics that select which attention scheme to use (i.e. flash attention vs. memory-efficient attention vs. math, etc.). The CK path only gets called when flash attention is both enabled at build time (via USE_FLASH_ATTENTION) and selected at runtime by the existing heuristics.
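For illustration, a minimal sketch of opting into the new backend from Python. The `preferred_rocm_fa_library` call is the API named above; pinning SDPA to the flash attention kernel via `sdpa_kernel` is standard PyTorch usage, and the tensor shapes are just an example rather than something this PR prescribes.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Opt in to the CK flash attention path on ROCm; "aotriton" or "default"
# restores the incumbent backend (API name taken from the PR description above).
torch.backends.cuda.preferred_rocm_fa_library("ck")

# Example shapes: (batch, heads, seq_len, head_dim) in bf16 on the GPU.
q, k, v = (torch.randn(8, 16, 1024, 128, device="cuda", dtype=torch.bfloat16)
           for _ in range(3))

# Restrict SDPA to the flash attention kernel so the CK path is actually exercised.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```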
Files located in pytorch/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha* have been pulled from https://github.com/Dao-AILab/flash-attention, courtesy of the hard work of @tridao, who is a co-author.
cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @rec
:link: Helpful Links
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/138947
- :page_facing_up: Preview Python docs built from this PR
- :page_facing_up: Preview C++ docs built from this PR
- :question: Need help or want to give feedback on the CI? Visit the bot commands wiki or our office hours
Note: Links to docs will display an error until the docs builds have been completed.
:x: 8 New Failures, 1 Unrelated Failure
As of commit ed20e391acab3559ec22b0c9d72e845a89462ff1 with merge base a17ecd8668e7ada0dbf6540e7c2cc153ba15ea7b:
NEW FAILURES - The following jobs have failed:
- Lint / lintrunner-noclang / linux-job (gh)
  >>> Lint for aten/src/ATen/CMakeLists.txt:
- macos-arm64-binary-libtorch-cxx11-abi / libtorch-cpu-shared-with-deps-cxx11-abi-build (gh)
  No files were found with the provided path: /Users/runner/work/_temp/artifacts. No artifacts will be uploaded.
- macos-arm64-binary-wheel / wheel-py3_10-cpu-build (gh)
  No files were found with the provided path: /Users/runner/work/_temp/artifacts. No artifacts will be uploaded.
- macos-arm64-binary-wheel / wheel-py3_11-cpu-build (gh)
  No files were found with the provided path: /Users/runner/work/_temp/artifacts. No artifacts will be uploaded.
- macos-arm64-binary-wheel / wheel-py3_12-cpu-build (gh)
  No files were found with the provided path: /Users/runner/work/_temp/artifacts. No artifacts will be uploaded.
- macos-arm64-binary-wheel / wheel-py3_13-cpu-build (gh)
  No files were found with the provided path: /Users/runner/work/_temp/artifacts. No artifacts will be uploaded.
- macos-arm64-binary-wheel / wheel-py3_9-cpu-build (gh)
  No files were found with the provided path: /Users/runner/work/_temp/artifacts. No artifacts will be uploaded.
- windows-binary-wheel / wheel-py3_12-cpu-test (gh)
  Process completed with exit code 1.
FLAKY - The following job failed but was likely due to flakiness present on trunk:
- rocm / linux-focal-rocm6.2-py3.10 / test (default, 4, 6, linux.rocm.gpu.2) (gh) (disabled by #141458 but the issue was closed recently and a rebase is needed to make it pass)
  test_linalg.py::TestLinalgCUDA::test_matmul_small_brute_force_tunableop_cuda_float16
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Can you highlight the coverage and performance differences between this and the aotriton-based version?
@zjing14 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@alugorey This is a pretty substantive PR, especially since it introduces a new CK backend for FA. Can you please capture the salient points in the PR description, e.g. changes to the PyTorch build, relevant env vars if any, default settings, how new dependencies are installed, functional coverage, perf comparison vs AOTriton (if available), etc.?
Please check the SDPA UT failure - FAILED [0.3226s] inductor/test_fused_attention.py::SDPAPatternRewriterCudaTests::test_sdpa_rewriter_10_cuda - AssertionError: wrong number of dimensions
This PR is very timely, thanks! Please rebase past the merge conflict, add performance numbers, and fix the unit test. We'd like to get it merged ASAP.
@xw285cornell Rebased the PR (and squashed the commits, as they became unwieldy during rebasing) and fixed the unit test errors. The fixes for those failures are in the second of the two commits. I believe I had incorrectly copied some code from a stale version of PyTorch; once I updated the file to match what PyTorch already had, all failing tests passed. I will update again once I have performance numbers.
@xw285cornell Here are performance numbers via https://github.com/pytorch/pytorch/blob/main/benchmarks/transformer/sdpa.py
AOTRITON:

| batch_size | num_heads | q_seq_len | kv_seq_len | embed_dim | is_causal | dtype | forward_time | backward_time |
|---|---|---|---|---|---|---|---|---|
| 1 | 16 | 128 | 128 | 2048 | True | torch.bfloat16 | 21.79451020201668 | 450.4763724980876 |
| 1 | 16 | 128 | 128 | 2048 | False | torch.bfloat16 | 20.076411101035774 | 417.15383948758245 |
| 1 | 16 | 256 | 256 | 2048 | True | torch.bfloat16 | 34.47094999719411 | 585.7575905974956 |
| 1 | 16 | 256 | 256 | 2048 | False | torch.bfloat16 | 25.168272096198052 | 536.6671353112906 |
| 1 | 16 | 512 | 512 | 2048 | True | torch.bfloat16 | 47.10616450756788 | 839.1657844185828 |
| 1 | 16 | 512 | 512 | 2048 | False | torch.bfloat16 | 56.71655246987939 | 800.2008404582739 |
| 1 | 16 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 78.82721989881247 | 1465.2727136854082 |
| 1 | 16 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 70.62346709426492 | 1246.6456997208297 |
| 8 | 16 | 128 | 128 | 2048 | True | torch.bfloat16 | 25.342435389757153 | 523.7751454114914 |
| 8 | 16 | 128 | 128 | 2048 | False | torch.bfloat16 | 21.52420529164374 | 426.74367886502296 |
| 8 | 16 | 256 | 256 | 2048 | True | torch.bfloat16 | 37.55110305501149 | 1003.9826692081989 |
| 8 | 16 | 256 | 256 | 2048 | False | torch.bfloat16 | 34.903176489751786 | 806.637704372406 |
| 8 | 16 | 512 | 512 | 2048 | True | torch.bfloat16 | 108.28629089519382 | 2677.0579942967743 |
| 8 | 16 | 512 | 512 | 2048 | False | torch.bfloat16 | 77.14254152961075 | 2057.310835225507 |
| 8 | 16 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 340.4519024770707 | 7974.930899217727 |
| 8 | 16 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 335.1199545431882 | 6494.933494832368 |
CK:

| batch_size | num_heads | q_seq_len | kv_seq_len | embed_dim | is_causal | dtype | forward_time | backward_time |
|---|---|---|---|---|---|---|---|---|
| 1 | 16 | 128 | 128 | 2048 | True | torch.bfloat16 | 17.66616189852357 | 187.1795894112438 |
| 1 | 16 | 128 | 128 | 2048 | False | torch.bfloat16 | 29.742619290482253 | 186.7577244993299 |
| 1 | 16 | 256 | 256 | 2048 | True | torch.bfloat16 | 30.34165904391557 | 186.68619252275676 |
| 1 | 16 | 256 | 256 | 2048 | False | torch.bfloat16 | 17.309418448712677 | 187.05361546017232 |
| 1 | 16 | 512 | 512 | 2048 | True | torch.bfloat16 | 31.12762535456568 | 187.47946387156844 |
| 1 | 16 | 512 | 512 | 2048 | False | torch.bfloat16 | 29.813363507855687 | 188.3397719357163 |
| 1 | 16 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 57.14489193633199 | 188.11975000426173 |
| 1 | 16 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 53.80507651716471 | 190.57486508972946 |
| 8 | 16 | 128 | 128 | 2048 | True | torch.bfloat16 | 30.706075858324763 | 185.7592734741047 |
| 8 | 16 | 128 | 128 | 2048 | False | torch.bfloat16 | 29.88582495599985 | 186.2289763521403 |
| 8 | 16 | 256 | 256 | 2048 | True | torch.bfloat16 | 30.88476264383644 | 188.25193995144218 |
| 8 | 16 | 256 | 256 | 2048 | False | torch.bfloat16 | 30.64896785654128 | 311.2350357696414 |
| 8 | 16 | 512 | 512 | 2048 | True | torch.bfloat16 | 67.67512508668005 | 338.631704216823 |
| 8 | 16 | 512 | 512 | 2048 | False | torch.bfloat16 | 72.19493587035686 | 348.0469754431396 |
| 8 | 16 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 226.54105187393725 | 905.5793448351324 |
| 8 | 16 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 253.85335448663685 | 925.6914001889527 |
Please let me know if you'd like anything else.
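(For readers who want to sanity-check numbers like these locally, below is a minimal, hedged timing sketch using the same shape layout as the tables. benchmarks/transformer/sdpa.py remains the authoritative harness; the iteration counts, units, and helper function here are illustrative assumptions, not that script's actual implementation.)

```python
import torch
import torch.nn.functional as F

# One configuration from the tables above: batch=8, heads=16, seq=1024, embed_dim=2048.
b, h, s, embed = 8, 16, 1024, 2048
d = embed // h  # per-head dimension
q, k, v = (torch.randn(b, h, s, d, device="cuda", dtype=torch.bfloat16)
           for _ in range(3))

def time_ms(fn, iters=30, warmup=5):
    # Event timing around `iters` calls of fn, returning mean milliseconds.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

def fwd():
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

print("forward (ms):", time_ms(fwd))  # the real benchmark also measures backward
```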
@alugorey thanks! It looks like the backward perf is pretty bad? Do we have all the optimizations in CK landed here? I thought there were some pretty good perf optimizations in CK in the past few months.
Looks like the snapshot of CK here is from October 23rd. Quite a few changes have gone in since then. I will update CK and repost performance to see if that helps.
@xw285cornell, we propose to merge the current PR and take up the CK update in a subsequent PR.
It seems as though the CK backend is universally better than the AOTriton version. Can we just fully replace the AOTriton implementation with CK so that we don't have to maintain both code paths?
Updating CK turned out to be trickier than anticipated: there were a bunch of linker errors in the newly generated kernels. We will follow up at a later date w.r.t. updating CK.
It's OK to keep this PR as is (please fix all the unit tests and populate the files), unless @drisspg disagrees :). But let's certainly follow up on updating CK in the near future; in the end we need good perf. Also, internally we use a different CK version, so we cannot land this PR if it breaks internal builds.
Hi @albanD, can we get an exception to the 2000-line limit for this PR?
@xw285cornell has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
There are a few failing tests. And for the generated files, please create a folder aten/src/ATen/native/transformers/hip/flash_attn/ck/instances. Or maybe at least move them into aten/src/ATen/native/transformers/hip/flash_attn/ck.
Also it seems we mix .hpp and .h - can we just use one? e.g. flash_api.h and flash_common_hip.hpp
Sure. Going to stick with .hpp for now as that is what the generated files expect
Alright, I guess we cannot change to .hpp because the hipify from CUDA requires flash_api.h. Probably change it back; not worth it...
Deploy Preview for chimerical-cranachan-793287 ready!
| Name | Link |
|---|---|
| Latest commit | ca4e3e2d73b888b3f08d7eb08a383475bf308dfb |
| Latest deploy log | https://app.netlify.com/sites/chimerical-cranachan-793287/deploys/67538b8e211ecc00083caf79 |
| Deploy Preview | https://deploy-preview-138947--chimerical-cranachan-793287.netlify.app |
Converted to draft to prevent this from going into the release branch. It will be picked up after the release branch is cut.
@malfet Looks like the binary builds for ROCm will break due to this PR with the current changes, because CK doesn't build successfully for all gfx architectures. Our original plan was to note the increase in binary build times on the assumption that it would build successfully, but that assumption is turning out to be incorrect. We are currently exploring the other option we discussed, whereby we integrate a prebuilt CK as a shared object into PyTorch, but that will take a few days probably.
I don't think we want to leave ROCm nightlies broken for an extended period of time while we work on the .so solution for CK, so I'm holding off merging this PR. Let me know if you have different thoughts.
cc @xw285cornell @jeffdaily @pruthvistony
Can't you guard off building kernels for certain architectures, using something like #ifdef __HIP_ARCHITECTURE == XYZ?
FWIW, a nice middle ground could be to force it to compile only for the supported architectures (hard-code the --offload-arch flags), and then only switch to the CK backend at runtime if the current GPU arch is one of the supported ones.
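A hedged Python sketch of that runtime gating idea. The supported-arch list below is purely illustrative (not an official statement of CK coverage), and the availability of gcnArchName should be verified on the ROCm build in use; preferred_rocm_fa_library is the knob introduced by this PR.

```python
import torch

# Illustrative list only; not an official statement of which gfx targets CK supports.
CK_SUPPORTED_ARCHS = {"gfx90a", "gfx942"}

def maybe_prefer_ck():
    """Opt in to the CK flash attention backend only on a supported ROCm GPU."""
    if torch.version.hip is None or not torch.cuda.is_available():
        return
    # gcnArchName can look like "gfx90a:sramecc+:xnack-"; keep the base arch token.
    arch = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
    if arch in CK_SUPPORTED_ARCHS:
        torch.backends.cuda.preferred_rocm_fa_library("ck")

maybe_prefer_ck()
```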
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here
Successfully rebased rocm_ck_sdpa onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout rocm_ck_sdpa && git pull --rebase)
@xw285cornell has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Sounds good, I'll work on this internally and land it from there. Don't land from here because it's going to break.