
[ROCm] CK Flash Attention Backend

Open alugorey opened this pull request 1 year ago • 20 comments

Replaces https://github.com/ROCm/pytorch/pull/1592

Updated implementation of the CK flash attention backend. The previous PR can be closed.

This PR generates the CK kernels necessary for flash attention. Currently they are generated under pytorch/aten/src/ATen/native/transformers/hip/flash_attn/. The CK path can be forced by calling torch.backends.cuda.preferred_rocm_fa_library("ck"); similarly, you can force the incumbent aotriton implementation by passing in "aotriton" or "default". As you'd expect, not setting this option results in aotriton being used as the backend. In the CK case, if PyTorch deems flash attention usable, it will use the CK path in all the same places aotriton would have been used. This PR makes no changes to the heuristics that select which attention scheme to use (i.e. flash attention vs. memory-efficient attention vs. math, etc.); the CK path only gets called when flash attention is both enabled (via USE_FLASH_ATTENTION) and selected at runtime by the existing heuristics.
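
For illustration, a minimal usage sketch of the backend selection described above, assuming a ROCm build of PyTorch with USE_FLASH_ATTENTION enabled (the tensor shapes and the sdpa_kernel context manager are illustrative, not part of this PR):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Route the ROCm flash-attention path through CK; "aotriton" or "default"
# restores the incumbent backend (leaving it unset also means aotriton).
torch.backends.cuda.preferred_rocm_fa_library("ck")

q, k, v = (torch.randn(8, 16, 512, 128, device="cuda", dtype=torch.bfloat16)
           for _ in range(3))

# Restrict SDPA to the flash-attention kernel so the CK path is exercised.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```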

Files located in pytorch/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha* have been pulled from https://github.com/Dao-AILab/flash-attention, courtesy of the hard work of @tridao, who is credited as a co-author.

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @rec

alugorey avatar Oct 25 '24 21:10 alugorey

:link: Helpful Links

:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/138947

Note: Links to docs will display an error until the docs builds have been completed.

:x: 8 New Failures, 1 Unrelated Failure

As of commit ed20e391acab3559ec22b0c9d72e845a89462ff1 with merge base a17ecd8668e7ada0dbf6540e7c2cc153ba15ea7b:

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] avatar Oct 25 '24 21:10 pytorch-bot[bot]

Can you highlight the coverage and performance differences between this and the AOTriton-based version?

drisspg avatar Oct 29 '24 01:10 drisspg

@zjing14 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot avatar Oct 30 '24 00:10 facebook-github-bot

@alugorey This is a pretty substantive PR, especially since it introduces a new CK backend for FA. Can you please capture the salient points in the PR description, e.g. changes to the PyTorch build, relevant env vars if any, default settings, how new dependencies are installed, functional coverage, perf comparison vs. AOTriton (if available), etc.?

jithunnair-amd avatar Oct 30 '24 02:10 jithunnair-amd

Please check the SDPA UT failure - FAILED [0.3226s] inductor/test_fused_attention.py::SDPAPatternRewriterCudaTests::test_sdpa_rewriter_10_cuda - AssertionError: wrong number of dimensions

pruthvistony avatar Oct 31 '24 04:10 pruthvistony

This PR is very timely, thanks! Please rebase past the merge conflict, add performance numbers, and fix the unit test. We'd like to get it merged ASAP.

xw285cornell avatar Nov 01 '24 15:11 xw285cornell

@xw285cornell Rebased the PR (and squashed the commits, since they became unwieldy during rebasing) and fixed the unit test errors. The fixes for those failures are in the second of the two commits. I believe I had incorrectly copied some code from a stale version of PyTorch; when I updated the file to match what PyTorch already had, all failing tests passed. I will update again once I have performance numbers.

alugorey avatar Nov 05 '24 19:11 alugorey

@xw285cornell Here are performance numbers via https://github.com/pytorch/pytorch/blob/main/benchmarks/transformer/sdpa.py:

AOTRITON:

| batch_size | num_heads | q_seq_len | kv_seq_len | embed_dim | is_causal | dtype | forward_time | backward_time |
|---|---|---|---|---|---|---|---|---|
| 1 | 16 | 128 | 128 | 2048 | True | torch.bfloat16 | 21.79451020201668 | 450.4763724980876 |
| 1 | 16 | 128 | 128 | 2048 | False | torch.bfloat16 | 20.076411101035774 | 417.15383948758245 |
| 1 | 16 | 256 | 256 | 2048 | True | torch.bfloat16 | 34.47094999719411 | 585.7575905974956 |
| 1 | 16 | 256 | 256 | 2048 | False | torch.bfloat16 | 25.168272096198052 | 536.6671353112906 |
| 1 | 16 | 512 | 512 | 2048 | True | torch.bfloat16 | 47.10616450756788 | 839.1657844185828 |
| 1 | 16 | 512 | 512 | 2048 | False | torch.bfloat16 | 56.71655246987939 | 800.2008404582739 |
| 1 | 16 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 78.82721989881247 | 1465.2727136854082 |
| 1 | 16 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 70.62346709426492 | 1246.6456997208297 |
| 8 | 16 | 128 | 128 | 2048 | True | torch.bfloat16 | 25.342435389757153 | 523.7751454114914 |
| 8 | 16 | 128 | 128 | 2048 | False | torch.bfloat16 | 21.52420529164374 | 426.74367886502296 |
| 8 | 16 | 256 | 256 | 2048 | True | torch.bfloat16 | 37.55110305501149 | 1003.9826692081989 |
| 8 | 16 | 256 | 256 | 2048 | False | torch.bfloat16 | 34.903176489751786 | 806.637704372406 |
| 8 | 16 | 512 | 512 | 2048 | True | torch.bfloat16 | 108.28629089519382 | 2677.0579942967743 |
| 8 | 16 | 512 | 512 | 2048 | False | torch.bfloat16 | 77.14254152961075 | 2057.310835225507 |
| 8 | 16 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 340.4519024770707 | 7974.930899217727 |
| 8 | 16 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 335.1199545431882 | 6494.933494832368 |

CK:

| batch_size | num_heads | q_seq_len | kv_seq_len | embed_dim | is_causal | dtype | forward_time | backward_time |
|---|---|---|---|---|---|---|---|---|
| 1 | 16 | 128 | 128 | 2048 | True | torch.bfloat16 | 17.66616189852357 | 187.1795894112438 |
| 1 | 16 | 128 | 128 | 2048 | False | torch.bfloat16 | 29.742619290482253 | 186.7577244993299 |
| 1 | 16 | 256 | 256 | 2048 | True | torch.bfloat16 | 30.34165904391557 | 186.68619252275676 |
| 1 | 16 | 256 | 256 | 2048 | False | torch.bfloat16 | 17.309418448712677 | 187.05361546017232 |
| 1 | 16 | 512 | 512 | 2048 | True | torch.bfloat16 | 31.12762535456568 | 187.47946387156844 |
| 1 | 16 | 512 | 512 | 2048 | False | torch.bfloat16 | 29.813363507855687 | 188.3397719357163 |
| 1 | 16 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 57.14489193633199 | 188.11975000426173 |
| 1 | 16 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 53.80507651716471 | 190.57486508972946 |
| 8 | 16 | 128 | 128 | 2048 | True | torch.bfloat16 | 30.706075858324763 | 185.7592734741047 |
| 8 | 16 | 128 | 128 | 2048 | False | torch.bfloat16 | 29.88582495599985 | 186.2289763521403 |
| 8 | 16 | 256 | 256 | 2048 | True | torch.bfloat16 | 30.88476264383644 | 188.25193995144218 |
| 8 | 16 | 256 | 256 | 2048 | False | torch.bfloat16 | 30.64896785654128 | 311.2350357696414 |
| 8 | 16 | 512 | 512 | 2048 | True | torch.bfloat16 | 67.67512508668005 | 338.631704216823 |
| 8 | 16 | 512 | 512 | 2048 | False | torch.bfloat16 | 72.19493587035686 | 348.0469754431396 |
| 8 | 16 | 1024 | 1024 | 2048 | True | torch.bfloat16 | 226.54105187393725 | 905.5793448351324 |
| 8 | 16 | 1024 | 1024 | 2048 | False | torch.bfloat16 | 253.85335448663685 | 925.6914001889527 |
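
For anyone wanting to sanity-check these numbers without the benchmark script, a rough standalone sketch of the same kind of forward+backward timing (using torch.utils.benchmark; the config below is one row from the tables, and timings will not match the script exactly):

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark

# One config from the tables above: batch 8, 16 heads, seq len 1024,
# embed_dim 2048 across 16 heads -> head_dim 128, bfloat16, causal.
q, k, v = (torch.randn(8, 16, 1024, 128, device="cuda",
                       dtype=torch.bfloat16, requires_grad=True)
           for _ in range(3))

def fwd_bwd():
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    out.sum().backward()

# Timer handles CUDA synchronization; reports time per iteration.
timer = benchmark.Timer(stmt="fwd_bwd()", globals={"fwd_bwd": fwd_bwd})
print(timer.timeit(100))
```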

Please let me know if you'd like anything else.

alugorey avatar Nov 05 '24 22:11 alugorey

@alugorey thanks! It looks like the backward perf is pretty bad? Do we have all the optimizations in CK landed here? I thought there have been some pretty good perf optimizations in CK over the past few months.

xw285cornell avatar Nov 06 '24 03:11 xw285cornell

> @alugorey thanks! It looks like the backward perf is pretty bad? Do we have all the optimizations in CK landed here? I thought there have been some pretty good perf optimizations in CK over the past few months.

Looks like the snapshot of CK here is from October 23rd. Quite a few changes have gone in since then. I will update CK and repost performance to see if that helps.

alugorey avatar Nov 06 '24 16:11 alugorey

> @alugorey thanks! It looks like the backward perf is pretty bad? Do we have all the optimizations in CK landed here? I thought there have been some pretty good perf optimizations in CK over the past few months.
>
> Looks like the snapshot of CK here is from October 23rd. Quite a few changes have gone in since then. I will update CK and repost performance to see if that helps.

@xw285cornell, we propose to merge the current PR and take up the CK update in a subsequent PR.

pruthvistony avatar Nov 06 '24 17:11 pruthvistony

It seems as though the CK backend is universally better than the AOTriton version. Can we just fully replace the AOTriton implementation with CK so that we don't have to maintain both code paths?

drisspg avatar Nov 06 '24 17:11 drisspg

> @alugorey thanks! It looks like the backward perf is pretty bad? Do we have all the optimizations in CK landed here? I thought there have been some pretty good perf optimizations in CK over the past few months.
>
> Looks like the snapshot of CK here is from October 23rd. Quite a few changes have gone in since then. I will update CK and repost performance to see if that helps.

This was trickier than anticipated; there are a bunch of linker errors in the newly generated kernels. We will follow up at a later date w.r.t. updating CK.

alugorey avatar Nov 06 '24 18:11 alugorey

It's ok to keep this PR as is (please fix all the unit tests and populate the files, unless @drisspg disagrees :) ). But let's certainly follow up on updating CK in the near future; in the end we need good perf. Also, internally we use a different CK version, so we cannot land this PR if it breaks internal builds.

xw285cornell avatar Nov 07 '24 17:11 xw285cornell

Hi @albanD, can we get an exception to the 2000-line limit for this PR?

alugorey avatar Nov 07 '24 19:11 alugorey

@xw285cornell has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot avatar Nov 13 '24 02:11 facebook-github-bot

@xw285cornell has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot avatar Nov 16 '24 01:11 facebook-github-bot

There are a few failing tests. And for the generated files, please create a folder aten/src/ATen/native/transformers/hip/flash_attn/ck/instances, or at least move them into aten/src/ATen/native/transformers/hip/flash_attn/ck.

Also, it seems we mix .hpp and .h; can we just use one? E.g. flash_api.h and flash_common_hip.hpp.

xw285cornell avatar Nov 16 '24 23:11 xw285cornell

> There are a few failing tests. And for the generated files, please create a folder aten/src/ATen/native/transformers/hip/flash_attn/ck/instances, or at least move them into aten/src/ATen/native/transformers/hip/flash_attn/ck.
>
> Also, it seems we mix .hpp and .h; can we just use one? E.g. flash_api.h and flash_common_hip.hpp.

Sure. Going to stick with .hpp for now, as that is what the generated files expect.

alugorey avatar Nov 18 '24 19:11 alugorey

Alright, I guess we cannot change to .hpp because the hipify step from CUDA requires flash_api.h. Probably change it back; not worth it...

xw285cornell avatar Nov 19 '24 06:11 xw285cornell

Deploy Preview for chimerical-cranachan-793287 ready!

| Name | Link |
|---|---|
| Latest commit | ca4e3e2d73b888b3f08d7eb08a383475bf308dfb |
| Latest deploy log | https://app.netlify.com/sites/chimerical-cranachan-793287/deploys/67538b8e211ecc00083caf79 |
| Deploy Preview | https://deploy-preview-138947--chimerical-cranachan-793287.netlify.app |

netlify[bot] avatar Dec 06 '24 18:12 netlify[bot]

Converted to draft to prevent this from going into the release branch. Will be picked up after the release branch is cut.

alugorey avatar Dec 09 '24 16:12 alugorey

@malfet Looks like the binary builds for ROCm will break due to this PR with the current changes, because CK doesn't build successfully for all gfx architectures. Our original plan was to note the increase in binary build times on the assumption that it would build successfully, but that assumption is turning out to be incorrect. We are currently exploring the other option we discussed, whereby we integrate a prebuilt CK as a shared object into PyTorch, but that will take a few days probably.

I don't think we want to leave ROCm nightlies broken for an extended period of time while we work on the .so solution for CK, so I'm holding off merging this PR. Let me know if you have different thoughts.

cc @xw285cornell @jeffdaily @pruthvistony

jithunnair-amd avatar Dec 10 '24 16:12 jithunnair-amd

> @malfet Looks like the binary builds for ROCm will break due to this PR with the current changes, because CK doesn't build successfully for all gfx architectures.

Can't you guard off building kernels for certain architectures, using something like #if __HIP_ARCHITECTURE == XYZ?

malfet avatar Dec 10 '24 16:12 malfet

> @malfet Looks like the binary builds for ROCm will break due to this PR with the current changes, because CK doesn't build successfully for all gfx architectures. Our original plan was to note the increase in binary build times on the assumption that it would build successfully, but that assumption is turning out to be incorrect. We are currently exploring the other option we discussed, whereby we integrate a prebuilt CK as a shared object into PyTorch, but that will take a few days probably.
>
> I don't think we want to leave ROCm nightlies broken for an extended period of time while we work on the .so solution for CK, so I'm holding off merging this PR. Let me know if you have different thoughts.
>
> cc @xw285cornell @jeffdaily @pruthvistony

FWIW, a nice middle ground could be to force it to compile only for the supported architectures (hard-code the --offload-arch flags), then only switch to the CK backend if the current GPU arch is one of the supported ones. A rough sketch of that gate is below.
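
As a sketch only: the supported-arch set below is hypothetical, not the set this PR actually builds for, and it assumes the ROCm build exposes gcnArchName on device properties:

```python
import torch

# Hypothetical set of gfx architectures the CK kernels were compiled for
# (i.e. the hard-coded --offload-arch list); NOT the PR's actual list.
CK_SUPPORTED_ARCHS = {"gfx90a", "gfx942"}

def prefer_ck_if_supported() -> None:
    # On ROCm builds, gcnArchName looks like "gfx90a:sramecc+:xnack-";
    # keep only the base architecture name.
    arch = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
    if arch in CK_SUPPORTED_ARCHS:
        torch.backends.cuda.preferred_rocm_fa_library("ck")

prefer_ck_if_supported()
```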

jammm avatar Dec 10 '24 16:12 jammm

@pytorchbot rebase

pruthvistony avatar Dec 13 '24 03:12 pruthvistony

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

pytorchmergebot avatar Dec 13 '24 03:12 pytorchmergebot

Successfully rebased rocm_ck_sdpa onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout rocm_ck_sdpa && git pull --rebase)

pytorchmergebot avatar Dec 13 '24 03:12 pytorchmergebot

@xw285cornell has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot avatar Dec 14 '24 05:12 facebook-github-bot

Sounds good, I'll work on this internally and land it from there. Don't land from here because it's going to break.

xw285cornell avatar Dec 15 '24 08:12 xw285cornell